Source code for zhusuan.mcmc.SGMCMC

import jittor as jt
from jittor import Module


[docs]class SGMCMC(Module): """ Base class for stochastic gradient MCMC (SGMCMC) algorithms. SGMCMC is a class of MCMC algorithms which utilize stochastic gradients instead of the true gradients. The typical code for SGMCMC inference is like:: sgmcmc = zs.mcmc.SGLD(learning_rate=lr) net = BayesianNet() for epoch in range(epoch_size): for step in range(num_batches): w_samples = model.sample(net, {'x': x, 'y': y}) for i, (k, w) in enumerate(w_samples.items()): # Utilize stochastic gradients by samples and update parameters. ... """ def __init__(self): super().__init__() self.t = 0 def _update(self, bn, observed): raise NotImplementedError()
[docs] def execute(self, bn, observed, resample=False, step=1): if resample: self.t = 0 bn.execute(observed) self.t += 1 self._latent = {k: v.tensor for k, v in bn.nodes.items() if k not in observed.keys()} self._latent_k = self._latent.keys() self._var_list = [self._latent[k] for k in self._latent_k] sample_ = dict(zip(self._latent_k, self._var_list)) for i in range(len(self._var_list)): self._var_list[i] = self._var_list[i].detach() return sample_ for s in range(step): self._update(bn, observed) self.t += 1 sample_ = dict(zip(self._latent_k, self._var_list)) return sample_
[docs] def initialize(self): self.t = 0
[docs] def sample(self, bn, observed, resample=False, step=1): """ Running one sgmcmc iteration. :param bn: A instance of :class:`~zhusuan.framework.bn.BayesianNet`. :param observed: A dictionary of ``(string, Tensor)`` pairs. Mapping from names of observed `StochasticTensor` s to their values. :param resample: Flag indicates if the sampler need get the var list of the :class:`~zhusuan.framework.bn.BayesianNet` instance, usually set to True on first sgmcmc iteration. :return: A list of Var, samples generated by sgmcmc iteration. """ return self.execute(bn, observed, resample, step)