Source code for zhusuan.distributions.bernoulli

import jittor as jt
import numpy as np

from zhusuan.distributions.base import Distribution


class Bernoulli(Distribution):
    """
    The class of univariate Bernoulli distribution.
    See :class:`~zhusuan.distributions.base.Distribution` for details.

    The probability mass function is ``p(x) = p^x (1 - p)^(1 - x)`` for
    ``x`` in ``{0, 1}``, where ``p`` is given by `probs`.

    :param probs: A `float` Var, passed as a keyword argument (required).
        The probabilities of being 1 (NOT log-odds); values are expected to
        lie in ``[0, 1]``. Stored with dtype `param_dtype`.
    :param dtype: The value type of samples from the distribution. Can be
        int (`int16`, `int32`, `int64`) or float (`float16`, `float32`,
        `float64`). Default is `float32`.
    :param param_dtype: The value type used to store `probs`. Default is
        `float32`.
    :param is_continues: Forwarded to
        :class:`~zhusuan.distributions.base.Distribution`. Default is False.
    :param is_reparameterized: Forwarded to
        :class:`~zhusuan.distributions.base.Distribution`. Default is True.
        NOTE(review): Bernoulli sampling is discrete, so a True default
        looks suspicious; confirm how the base class uses this flag.
    :param group_ndims: The number of dimensions in `batch_shape` (counted
        from the end) that are grouped into a single event, so that their
        probabilities are calculated together. Default is 0, which means a
        single value is an event. See
        :class:`~zhusuan.distributions.base.Distribution` for more detailed
        explanation.
    """

    def __init__(self,
                 dtype='float32',
                 param_dtype='float32',
                 is_continues=False,
                 is_reparameterized=True,
                 group_ndims=0,
                 **kwargs):
        super(Bernoulli, self).__init__(dtype,
                                        param_dtype,
                                        is_continues,
                                        is_reparameterized,
                                        group_ndims=group_ndims,
                                        **kwargs)
        # `probs` is a parameter, so keep it in the parameter dtype. Casting
        # it to the sample dtype would truncate every probability in (0, 1)
        # to 0 whenever an integer sample dtype is requested.
        self._probs = jt.cast(kwargs['probs'], param_dtype)

    @property
    def probs(self):
        """The probabilities of being 1, as passed at construction."""
        return self._probs

    def _batch_shape(self):
        """Return the batch shape, i.e. the shape of `probs`."""
        return self.probs.shape

    def _sample(self, n_samples=1, **kwargs):
        """
        Draw samples and cache them in ``self.sample_cache``.

        :param n_samples: Number of samples. When greater than 1, a leading
            dimension of size `n_samples` is prepended to the batch shape.
            Otherwise the result has the batch shape itself.
        :return: A Var of samples in ``{0, 1}``, cast to the sample dtype.
        """
        if n_samples > 1:
            sample_shape = np.concatenate([[n_samples], self.batch_shape],
                                          axis=0).tolist()
            # Broadcast probs across the new leading sample dimension.
            _probs = self._probs * jt.ones(sample_shape)
        else:
            _probs = self._probs
        # Zero out probabilities above 1. Build a new Var rather than using
        # `*=`: on this path `_probs` may be `self._probs` itself, and an
        # in-place update would overwrite the stored parameter.
        _probs = _probs * jt.cast(_probs <= 1, self._dtype)
        _sample = jt.cast(jt.bernoulli(_probs), self._dtype)
        self.sample_cache = _sample
        return _sample

    def _log_prob(self, sample=None):
        """
        Compute the element-wise log-probability of `sample`.

        :param sample: A Var of values in ``{0, 1}``. If None, the samples
            cached by the most recent call to `_sample` are used.
        :return: A Var of log-probabilities with the shape of `sample`.
        """
        if sample is None:
            sample = self.sample_cache
        if len(sample.shape) > len(self._probs.shape):
            # `sample` carries a leading sample dimension; broadcast probs
            # to [sample.shape[0]] + batch_shape to match it.
            sample_shape = np.concatenate([[sample.shape[0]],
                                           self.batch_shape],
                                          axis=0).tolist()
            _probs = self._probs * jt.ones(sample_shape)
        else:
            _probs = self._probs
        # The 1e-8 offset keeps log() finite when a probability is 0 or 1.
        log_prob = (sample * jt.log(_probs + 1e-8)
                    + (1 - sample) * jt.log(1 - _probs + 1e-8))
        return log_prob