zhusuan.framework.stochastic_tensor

class StochasticTensor(bn, name, dist, observation=None, **kwargs)[source]

Bases: object

The StochasticTensor class represents the stochastic nodes in a BayesianNet. We can use any distribution available in zhusuan.distributions to construct a stochastic node in a BayesianNet. For example:

class Net(BayesianNet):
    def __init__(self):
        super().__init__()
        # Keep the return value so the node can be used in later computation.
        x = self.stochastic_node('Normal', name='x', mean=0., std=1.)

will build a stochastic node in Net with the Normal distribution. The returned x will be an instance of StochasticTensor.

StochasticTensor instances are Vars, which means that they can be passed into any Jittor operation. This makes it easy to build Bayesian networks by mixing stochastic nodes and Jittor primitives.

See also

For more information, please refer to Basic Concepts in ZhuSuan.

Parameters
  • bn – A BayesianNet.

  • name – A string. The name of the StochasticTensor. Must be unique in a BayesianNet.

  • dist – A Distribution instance that determines the distribution used in this stochastic node.

  • observation – A Var, which matches the shape of dist. If specified, then the StochasticTensor is observed and the tensor property will return the observation.

  • n_samples – A 0-D integer. Number of samples generated by this StochasticTensor. It is not a named argument in the signature above, so it is passed through **kwargs.

property bn

The BayesianNet where the StochasticTensor lives.

Returns

A BayesianNet instance.

property dist

The distribution followed by the StochasticTensor.

Returns

A Distribution instance.

property dtype

The sample type of the StochasticTensor.

Returns

A DType instance.

is_observed()[source]

Whether the StochasticTensor is observed or not.

Returns

A bool.

log_prob(sample=None)[source]

property name

The name of the StochasticTensor.

Returns

A string.

property shape

The static shape of this StochasticTensor.

Returns

A jittor_core.NanoVector instance.

property tensor

The value of this StochasticTensor. If it is observed, the observation is returned; otherwise, samples are returned.

Returns

A Var.
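The observed-versus-sampled behaviour described for tensor and is_observed() can be illustrated with a small pure-Python sketch. This is not ZhuSuan's implementation: ToyStochasticNode, its sampler argument, and the caching of the drawn sample are all hypothetical, and plain floats stand in for Jittor Vars.

```python
import random


class ToyStochasticNode:
    """Hypothetical stand-in for a stochastic node; not ZhuSuan's class."""

    def __init__(self, name, sampler, observation=None):
        self.name = name
        self._sampler = sampler          # zero-argument callable that draws a sample
        self._observation = observation  # fixed value, or None if unobserved
        self._sample = None              # cached draw (caching is an assumption)

    def is_observed(self):
        # A node counts as observed when an observation was supplied.
        return self._observation is not None

    @property
    def tensor(self):
        # Observed nodes return the observation unchanged.
        if self.is_observed():
            return self._observation
        # Unobserved nodes draw a sample once and reuse it afterwards.
        if self._sample is None:
            self._sample = self._sampler()
        return self._sample


rng = random.Random(0)
latent = ToyStochasticNode("z", lambda: rng.gauss(0.0, 1.0))
observed = ToyStochasticNode("x", lambda: rng.gauss(0.0, 1.0), observation=1.5)

print(observed.is_observed(), observed.tensor)
print(latent.is_observed(), latent.tensor)
```

In this sketch the observed node always yields the supplied value, while the latent node yields a value drawn from its sampler. Whether the real class caches its samples in the same way is not stated in this reference.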