# Source code for zhusuan.transforms.invertible.scaling

import jittor as jt
from jittor import nn

from zhusuan.transforms.invertible import InvertibleTransform

class Scaling(InvertibleTransform):
    r"""
    The scaling layer described in the NICE paper :cite:`scaling-dinh2015nice`,
    which computes the following process and its inverse.

    .. math::

        \begin{bmatrix}
            S_1 &     &        &     \\
                & S_2 &        &     \\
                &     & \ddots &     \\
                &     &        & S_D
        \end{bmatrix}
        \begin{bmatrix}
            h_{i - 1, 1} \\
            h_{i - 1, 2} \\
            \vdots       \\
            h_{i - 1, D}
        \end{bmatrix}
        =
        \begin{bmatrix}
            h_{i, 1} \\
            h_{i, 2} \\
            \vdots   \\
            h_{i, D}
        \end{bmatrix}

    where :math:`S_d = \exp(s_d)` and :math:`s` is the learnable
    ``log_scale`` parameter of shape ``[1, n_dim]``.

    :param n_dim: The dim of the Var to be transformed.

    .. rubric:: References

    .. bibliography:: ../refs.bib
        :style: unsrtalpha
        :keyprefix: scaling-
    """

    def __init__(self, n_dim):
        super().__init__()
        # Learnable per-dimension log-scales. Parametrising the scale through
        # its logarithm keeps S_d = exp(s_d) strictly positive, so the map is
        # always invertible. No ``value`` is passed, so the initial value is
        # jittor's ``nn.init.constant`` default (presumably 0.0, i.e. the
        # identity map -- confirm against the installed jittor version).
        self.log_scale = nn.init.constant(shape=[1, n_dim], dtype='float32')

    def _forward(self, x, **kwargs):
        """
        Scale ``x`` element-wise by ``exp(log_scale)``.

        :param x: Input Var. Assumed to broadcast against ``log_scale``
            (shape ``[1, n_dim]``), e.g. ``[batch, n_dim]`` -- TODO confirm.
        :return: Tuple ``(z, log_detJ)``. ``z`` is the scaled Var and
            ``log_detJ`` is a copy of the per-dimension log-scales (shape
            ``[1, n_dim]``); the full log|det J| is their sum, which is
            presumably reduced by the caller.
        """
        # Copy so downstream in-place updates cannot alias the parameter.
        log_detJ = self.log_scale.clone()
        # Out-of-place multiply: the previous ``x *= ...`` overwrote the
        # caller's input Var in place.
        z = x * jt.exp(self.log_scale)
        return z, log_detJ

    def _inverse(self, z, **kwargs):
        """
        Undo the scaling by multiplying ``z`` by ``exp(-log_scale)``.

        :param z: Var in the transformed space, broadcastable against
            ``log_scale`` like the input of ``_forward``.
        :return: Tuple ``(x, None)``; no log-determinant is computed on the
            inverse pass.
        """
        # Out-of-place for the same reason as in ``_forward``.
        x = z * jt.exp(-self.log_scale)
        return x, None