Source code for zhusuan.mcmc.SGLD

import torch

import math

from zhusuan.mcmc.SGMCMC import SGMCMC


class SGLD(SGMCMC):
    """
    Subclass of SGMCMC which implements the Stochastic Gradient Langevin
    Dynamics (Welling & Teh, 2011) (SGLD) update. The updating equation
    implemented below follows Equation (3) in the paper.

    Each update replaces every latent variable ``x`` by
    ``x + 0.5 * lr * grad + noise``, where ``grad`` is the gradient of the
    log joint with respect to ``x`` and ``noise`` is zero-mean Gaussian with
    standard deviation ``sqrt(lr)``.

    * **var_list** - The updated values of latent variables.

    :param learning_rate: A 0-D `float32` Var.
    """

    def __init__(self, learning_rate):
        super().__init__()
        self.lr = torch.as_tensor(learning_rate)
        # Not read anywhere in this class; kept because it is a public
        # attribute that code outside this block may rely on.
        self.lr_min = torch.as_tensor(1e-4)
        # NOTICE: the device defaults to CPU until `to` is called.
        self._device = torch.device('cpu')

    @property
    def device(self):
        """
        The device this module lies at.

        Returns the device of the first entry yielded by
        ``self.parameters()``; when that is not available, falls back to the
        device recorded by :meth:`to` (CPU by default).

        :return: torch.device
        """
        try:
            return next(self.parameters()).device
        except (StopIteration, AttributeError):
            # StopIteration: no parameters are registered.
            # AttributeError: presumably the base class exposes no
            # `parameters()` -- kept so the fallback stays best-effort.
            return self._device

    def to(self, device):
        """
        Record ``device`` as the fallback device and delegate to the parent
        class's ``to``.

        :param device: The target device.
        :return: Whatever the parent class's ``to`` returns.
        """
        self._device = device
        return super().to(device)

    def _update(self, bn, observed):
        """
        Perform one SGLD step on every tensor in ``self._var_list``.

        :param bn: Object providing ``forward(observed)`` and ``log_joint()``.
        :param observed: Mapping of observed values; its entries override
            latent entries that share the same key.
        """
        # Current latent values keyed by name, then overridden by `observed`.
        observed_ = {**dict(zip(self._latent_k, self._var_list)), **observed}
        bn.forward(observed_)
        log_joint_ = bn.log_joint()
        grad = torch.autograd.grad(log_joint_, self._var_list)

        # The noise scale does not depend on the variable: compute it once.
        noise_std = math.sqrt(self.lr)

        for i, g in enumerate(grad):
            var = self._var_list[i]
            # Sample directly on the variable's own device/dtype rather than
            # sampling on CPU and copying; this also avoids a device mismatch
            # when the recorded fallback device differs from the variable's.
            epsilon = torch.randn_like(var) * noise_std
            # Detach so the next step starts from a fresh leaf tensor.
            new_var = (var + 0.5 * self.lr * g + epsilon).detach()
            new_var.requires_grad = True
            self._var_list[i] = new_var


class PSGLD(SGLD):
    """
    PSGLD with RMSprop preconditioner, "Preconditioned stochastic gradient
    Langevin dynamics for deep neural networks".

    A running average ``aux`` of squared gradients is kept per latent
    variable; the preconditioner is ``g = 1 / (epsilon + sqrt(aux))``. Each
    update replaces a latent variable ``x`` by
    ``x + 0.5 * lr * g * grad + noise``, where ``noise`` is zero-mean
    Gaussian with element-wise standard deviation ``sqrt(lr * g)``.

    :param learning_rate: Passed through to :class:`SGLD`.
    :param decay: Weight given to the previous running average of squared
        gradients; ``1 - decay`` is given to the new squared gradient.
    :param epsilon: Constant added to ``sqrt(aux)`` before inverting.
    """

    def __init__(self, learning_rate, decay=0.9, epsilon=1e-3):
        super().__init__(learning_rate)
        # Running averages of squared gradients, created lazily in `_update`.
        self.aux = None
        self.decay = decay
        self.epsilon = epsilon

    def _update(self, bn, observed):
        """
        Perform one preconditioned SGLD step on every tensor in
        ``self._var_list``.

        :param bn: Object providing ``forward(observed)`` and ``log_joint()``.
        :param observed: Mapping of observed values; its entries override
            latent entries that share the same key.
        """
        if not self.aux:
            # First call (or nothing stored yet): start the averages at zero.
            self.aux = [torch.zeros_like(q) for q in self._var_list]

        # Current latent values keyed by name, then overridden by `observed`.
        observed_ = {**dict(zip(self._latent_k, self._var_list)), **observed}
        bn.forward(observed_)
        log_joint_ = bn.log_joint()
        grad = torch.autograd.grad(log_joint_, self._var_list)

        for i, grad_i in enumerate(grad):
            self.aux[i] = (
                self.decay * self.aux[i]
                + (1 - self.decay) * torch.pow(grad_i, 2)
            )
            g = 1 / (self.epsilon + torch.sqrt(self.aux[i]))
            # `torch.normal` with a tensor std returns a tensor on that std's
            # device, i.e. the device of the latent variable, so no further
            # device transfer is needed (moving it to a separately recorded
            # device could only introduce a mismatch).
            e = torch.normal(0.0, torch.sqrt(self.lr * g))
            # Detach so the next step starts from a fresh leaf tensor.
            new_var = (self._var_list[i] + 0.5 * self.lr * g * grad_i + e).detach()
            new_var.requires_grad = True
            self._var_list[i] = new_var