# Source code for zhusuan.tests.transforms.base

import jittor as jt
import unittest
import numpy as np

class TestInvertibleTransform(unittest.TestCase):
    """Base test case providing a round-trip check for invertible transforms."""

    def assert_invertible(self, *inputs, transform=None, decimal=7):
        """Assert that ``transform`` reconstructs ``inputs`` after a round trip.

        The transform is applied with ``inverse=False`` and its output is fed
        back through ``transform.execute(..., inverse=True)``. The
        reconstruction is compared with the original inputs using
        ``np.testing.assert_almost_equal``.

        Args:
            *inputs: Objects passed positionally to ``transform.execute``.
                Each compared input must expose a ``numpy()`` method.
            transform: Object exposing ``execute(*args, inverse=...)``.
                Required, although it is declared as a keyword with a
                ``None`` default.
            decimal: Number of decimal places the reconstruction must match.

        Raises:
            TypeError: If ``transform`` is not supplied.
            AssertionError: If a reconstruction differs from its input.
        """
        if transform is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise TypeError(
                "assert_invertible() requires a 'transform' keyword argument"
            )

        # The forward pass returns a pair; only its first element is checked.
        z, _ = transform.execute(*inputs, inverse=False)

        if not isinstance(z, tuple):
            # Single output: the inverse call returns a pair whose first
            # element is the reconstruction.
            xr, _ = transform.execute(z, inverse=True)
            np.testing.assert_almost_equal(
                inputs[0].numpy(), xr.numpy(), decimal=decimal
            )
            return

        # Tuple output: the inverse result is indexed per element as-is.
        # NOTE(review): unlike the single-output branch, this result is not
        # unpacked as (x, log_det) -- confirm what ``execute`` returns for
        # tuple inputs with ``inverse=True``.
        xr = transform.execute(*z, inverse=True)
        for index, _ in enumerate(z):
            np.testing.assert_almost_equal(
                inputs[index].numpy(), xr[index].numpy(), decimal=decimal
            )