# Source code for zhusuan.mcmc.SGHMC

import math

import torch

from zhusuan.mcmc.SGMCMC import SGMCMC

# Names exported by ``from <this module> import *``: only the SGHMC class.
__all__ = [
    "SGHMC",
]


class SGHMC(SGMCMC):
    """Stochastic Gradient Hamiltonian Monte Carlo sampler.

    Each latent variable carries a velocity tensor.  One call to
    :meth:`_update` optionally redraws the velocities, evaluates the gradient
    of ``bn.log_joint()`` with respect to the latent variables, and then moves
    velocities and variables using that gradient, a friction decay and freshly
    drawn Gaussian noise.

    :param learning_rate: Step size.  It scales the gradient in the velocity
        update, and velocities are (re)drawn with variance ``learning_rate``.
    :param friction: Friction coefficient (``alpha``) that damps the velocity.
    :param variance_estimate: Value (``beta``) subtracted from the friction in
        the injected noise, whose variance is
        ``2 * (friction - variance_estimate) * learning_rate``.  It must not
        exceed ``friction``, otherwise the square root in :meth:`_update`
        raises ``ValueError``.
    :param n_iter_resample_v: Velocities are redrawn whenever ``self.t`` is a
        multiple of this number.  ``0`` or ``None`` disables resampling.
    :param second_order: If true, use the symmetric scheme (half position
        step, velocity update, half position step); otherwise update the
        velocity and take one full position step.
    """

    def __init__(self, learning_rate, friction=0.25, variance_estimate=0.,
                 n_iter_resample_v=20, second_order=True):
        super(SGHMC, self).__init__()
        self.lr = learning_rate
        self.alpha = friction
        self.beta = variance_estimate
        # ``None`` is accepted as a synonym for 0, i.e. "never resample".
        if n_iter_resample_v is None:
            n_iter_resample_v = 0
        self.n_iter_resample_v = n_iter_resample_v
        self.second_order = second_order
        self.vs = None  # velocities, one tensor per latent variable

    @staticmethod
    def _sample_normal_like(ref, std):
        """Draw zero-mean Gaussian noise matching ``ref``'s shape, dtype and device."""
        return torch.normal(mean=0., std=std, size=ref.shape,
                            dtype=ref.dtype, device=ref.device)

    @staticmethod
    def _as_leaf(tensor):
        """Return ``tensor`` cut from its autograd graph and marked as requiring grad."""
        leaf = tensor.detach()
        leaf.requires_grad = True
        return leaf

    def _update(self, bn, observed):
        """Advance every latent variable by one SGHMC step, in place.

        Relies on ``self._var_list``, ``self._latent_k`` and ``self.t``, which
        are not defined in this class (presumably maintained by ``SGMCMC`` --
        confirm against the base class).

        :param bn: Model object; ``bn.forward(observed_)`` is called and then
            ``bn.log_joint()`` is differentiated w.r.t. the latent variables.
        :param observed: Mapping of observed values.  It is merged over the
            current latent values, so on a key clash ``observed`` wins.
        """
        velocity_std = math.sqrt(self.lr)

        if not self.vs:
            self.vs = [self._sample_normal_like(q, velocity_std)
                       for q in self._var_list]

        # ``self.t`` is only consulted when resampling is enabled.
        resample = (self.n_iter_resample_v != 0
                    and self.t % self.n_iter_resample_v == 0)

        # One independent noise tensor per variable.  Keeping a single shared
        # tensor would inject the last variable's noise into every velocity
        # and break as soon as the variables differ in shape.
        noises = []
        for i, v in enumerate(self.vs):
            if resample:
                self.vs[i] = self._sample_normal_like(v, velocity_std)
            noise_std = math.sqrt(2 * (self.alpha - self.beta) * self.lr)
            noises.append(self._sample_normal_like(self.vs[i], noise_std))
            if self.second_order:
                # First half position step; the gradient is evaluated at the
                # midpoint.
                self._var_list[i] = self._as_leaf(
                    self._var_list[i] + 0.5 * self.vs[i])

        observed_ = {**dict(zip(self._latent_k, self._var_list)), **observed}
        bn.forward(observed_)
        log_joint_ = bn.log_joint()
        grads = torch.autograd.grad(log_joint_, self._var_list)

        if self.second_order:
            # Friction is applied as two half-decays around the gradient kick.
            decay_half = math.exp(-0.5 * self.alpha)
        for i, (grad, noise) in enumerate(zip(grads, noises)):
            if self.second_order:
                self.vs[i] = decay_half * (
                    decay_half * self.vs[i] + self.lr * grad + noise)
                step = 0.5 * self.vs[i]  # second half position step
            else:
                self.vs[i] = ((1 - self.alpha) * self.vs[i]
                              + self.lr * grad + noise)
                step = self.vs[i]
            # Re-create the variable as a fresh leaf so the next call can
            # differentiate with respect to it.
            self._var_list[i] = self._as_leaf(self._var_list[i] + step)