Source code for zhusuan.tests.transforms.invertible.test_scaling
import jittor as jt
import unittest
from zhusuan.tests.transforms import TestInvertibleTransform
from zhusuan.transforms.invertible import Scaling
class TestScaling(TestInvertibleTransform):
    """Unit tests for the ``Scaling`` invertible transform."""

    def test_invertible(self):
        """Check that ``Scaling`` passes ``assert_invertible`` on a random batch.

        ``assert_invertible`` is inherited from ``TestInvertibleTransform``
        (not shown here). It presumably applies the transform forward and
        then inverts it, checking that ``x`` is recovered. TODO confirm
        against the base class.
        """
        batch_size = 10
        in_out_dim = 10
        # Random standard-normal input of shape (batch_size, in_out_dim).
        x = jt.randn([batch_size, in_out_dim])
        # The transform is constructed with the feature dimension of x.
        t = Scaling(in_out_dim)
        self.assert_invertible(x, transform=t)
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()