# Source code for zhusuan.transforms.base
import jittor as jt
from jittor import Module
class Transform(Module):
    r"""
    Base class for Transforms.

    Subclasses provide the actual computation by overriding :meth:`_forward`
    and :meth:`_inverse`. Callers select the direction through
    :meth:`execute` via its ``inverse`` keyword.
    """

    def __init__(self):
        super().__init__()
        # Defaults to True here. NOTE(review): nothing in this class reads the
        # flag (execute() does not check it); presumably subclasses or callers
        # elsewhere consult it -- confirm.
        self.is_invertible = True

    def _forward(self, *args, **kwargs):
        r"""
        Forward transform.

        Compute :math:`x \mapsto z` and the log_abs determinant jacobian term.

        :raises NotImplementedError: always, in this base class; subclasses
            must override it.
        """
        raise NotImplementedError()

    def _inverse(self, *args, **kwargs):
        r"""
        Inverse transform.

        Compute :math:`z \mapsto x`.

        :raises NotImplementedError: always, in this base class; subclasses
            must override it.
        """
        raise NotImplementedError()

    def execute(self, *args, inverse=False, **kwargs):
        r"""
        Do forward and inverse transform.

        * **Forward transform**: Compute :math:`x \mapsto z` and the log_abs
          determinant jacobian term.
        * **Inverse transform**: Compute :math:`z \mapsto x`.

        All positional and keyword arguments other than ``inverse`` are
        forwarded unchanged to :meth:`_forward` / :meth:`_inverse`, and that
        method's return value is returned as-is.

        :param inverse: A Bool. Indicates whether execute the forward transform
            (``False``, the default) or the inverse transform (``True``).
        """
        if inverse:
            return self._inverse(*args, **kwargs)
        return self._forward(*args, **kwargs)