Source code for zhusuan.mcmc.SGLD

import jittor as jt

import math

from zhusuan.mcmc.SGMCMC import SGMCMC


[docs]class SGLD(SGMCMC): """ Subclass of SGMCMC which implements Stochastic Gradient Langevin Dynamics (Welling & Teh, 2011) (SGLD) update. The updating equation implemented below follows Equation (3) in the paper. * **var_list** - The updated values of latent variables. :param learning_rate: A 0-D `float32` Var. """ def __init__(self, learning_rate): super().__init__() self.lr = jt.array(learning_rate) self.lr_min = jt.array(1e-4) def _update(self, bn, observed): observed_ = {**dict(zip(self._latent_k, self._var_list)), **observed} bn.execute(observed_) log_joint_ = bn.log_joint() grad = jt.grad(log_joint_, self._var_list) for i, _ in enumerate(grad): epsilon = jt.normal(0., math.sqrt(self.lr), size=self._var_list[i].shape) self._var_list[i] = self._var_list[i] + 0.5 * self.lr * grad[i] + epsilon self._var_list[i] = self._var_list[i].detach()