Source code for zhusuan.framework.bn

import jittor as jt
from jittor import Module

from zhusuan.framework.stochastic_tensor import StochasticTensor
from zhusuan.distributions import *
from zhusuan.flows import *


[docs]class BayesianNet(Module): def __init__(self, observed=None): """ The :class:`BayesianNet` class provides a convenient way to construct Bayesian networks, i.e., directed graphical models. To start, we create a :class:`BayesianNet` instance:: class Net(BayesianNet): def __init__(self): # Initialize... def execute(self, observed): # Forward propagation... A :class:`BayesianNet` keeps two kinds of nodes * deterministic nodes: they are just Vars, usually the outputs of Jittor operations. * stochastic nodes: they are random variables in graphical models, and can be constructed inside the ``BayesianNet`` class like :: self.stochastic_node('Normal', name='w', mean=0., std=1.) To observe any stochastic nodes in the network, pass a dictionary mapping of ``(name, Tensor)`` pairs when call the model which derives :class:`BayesianNet`. This will assign observed values to corresponding :class:`StochasticTensor` s. For example:: net = Net() net({'w': w_obs}) will set ``w`` to be observed. .. note:: The observation passed must have the same type and shape as the :class:`StochasticTensor`. .. seealso:: For more details and examples, please refer to :doc:`/tutorials/concepts`. :param observed: A dictionary of (string, Tensor) pairs, which maps from names of stochastic nodes to their observed values. """ super(BayesianNet, self).__init__() self._nodes = {} self._cache = {} self._observed = observed if observed else {} @property def nodes(self): """ The dictionary of all named stochastic nodes in this :class:`BayesianNet`. :return: A dict. """ return self._nodes @property def cache(self): """ The dictionary of all named deterministic nodes in this :class:`BayesianNet`. :return: A dict. """ return self._cache @property def observed(self): """ The dictionary of all observed nodes in this :class:`BayesianNet`. :return: A dict. """ return self._observed
[docs] def observe(self, observed): """ Assign the nodes and values to be observed in this :class:`BayesianNet`. :param observed: A dictionary of (string, Tensor) pairs, which maps from names of stochastic nodes to their observed values. """ self._observed = {} for k, v in observed.items(): self._observed[k] = v return self
[docs] def sn(self, *args, **kwargs): """ Short cut for method :meth:`~zhusuan.framework.bn.BayesianNet.stochastic_node` """ return self.stochastic_node(*args, **kwargs)
[docs] def stochastic_node(self, distribution, name, **kwargs): """ Add a stochastic node in this :class:`BayesianNet` that follows the distribution assigned by the `name` parameter. :param distribution: The distribution which the node follows. :param name: The unique name of the node. :param **kwargs: Parameters of the distribution which the node builds with. :return: A instance(sample) of the node. """ _dist = globals()[distribution](**kwargs) self._nodes[name] = StochasticTensor(self, name, _dist, **kwargs) return self._nodes[name].tensor
def _log_joint(self): _ret = 0 for k, v in self._nodes.items(): if isinstance(v, StochasticTensor): try: _ret = _ret + v.log_prob() except: _ret = v.log_prob() return _ret
[docs] def log_joint(self, use_cache=False): """ The default log joint probability of this :class:`BayesianNet`. It works by summing over all the conditional log probabilities of stochastic nodes evaluated at their current values (samples or observations). :return: A Var. """ if use_cache: if not hasattr(self, '_log_joint_cache'): self._log_joint_cache = self._log_joint() else: self._log_joint_cache = self._log_joint() return self._log_joint_cache