Source code for zhusuan.transforms.sequential

import jittor as jt
from jittor import nn

from zhusuan.transforms.base import Transform

class Sequential(Transform):
    """
    A container which stacks transforms together and applies them
    sequentially: in the given order for the forward transform and in
    reverse order for the inverse transform.

    :param modules: A list of :class:`~zhusuan.transforms.base.Transform`.
    """

    def __init__(self, modules):
        super().__init__()
        self.modules = nn.Sequential(modules)

    @staticmethod
    def _split_output(output):
        """
        Validate the raw return value of one stacked transform and split it
        into the transformed variables and the trailing log-determinant slot.

        :param output: The value returned by a stacked transform. It must be
            a tuple with at least two elements whose last element is the
            log-determinant term (or ``None``).

        :return: A pair ``(variables, log_det)``. ``variables`` is always a
            tuple; ``log_det`` is the last element of ``output``.
        """
        assert isinstance(output, tuple), \
            "each stacked transform must return a tuple"
        assert len(output) >= 2, \
            "each stacked transform must return (variables..., log_det)"
        head = output[0]
        # A transform may hand back its variables already packed in a tuple
        # as the first element, or spread out in front of the log-det slot.
        variables = head if isinstance(head, tuple) else output[:-1]
        return variables, output[-1]

    def _forward(self, *x, **kwargs):
        """
        Apply every stacked transform in order.

        :param *x: Arbitrary number of variables to be transformed.
        :param **kwargs: Auxiliary parameters, passed to every transform.

        :return: A tuple ``(x, log_det)``. ``x`` is the tuple of transformed
            variables and ``log_det`` is the sum of the log-abs-determinant
            terms reported by the transforms (``jt.zeros([1])`` when no
            transform reported one).
        """
        log_detJ = []
        # Index-based access: only len() and integer indexing are relied on
        # for the underlying container.
        for i in range(len(self.modules)):
            output = self.modules[i](*x, inverse=False, **kwargs)
            x, log_det = self._split_output(output)
            if log_det is not None:
                log_detJ.append(log_det)
        return x, sum(log_detJ) if log_detJ else jt.zeros([1])

    def _inverse(self, *z, **kwargs):
        """
        Apply the inverse of every stacked transform, last transform first.

        :param *z: Arbitrary number of variables to be transformed back.
        :param **kwargs: Auxiliary parameters, passed to every transform.

        :return: The tuple of inverse-transformed variables.
        """
        # No log_det(jacobian) in inverse process
        for i in reversed(range(len(self.modules))):
            output = self.modules[i](*z, inverse=True, **kwargs)
            z, log_det = self._split_output(output)
            assert log_det is None, \
                "inverse transforms must not return a log-determinant term"
        return z