import torch
from torch._C import device
import torch.nn as nn
from zhusuan.framework.stochastic_tensor import StochasticTensor
from zhusuan.distributions import *
# Lookup table from the distribution names accepted as strings by
# ``BayesianNet.stochastic_node`` to the distribution classes they select.
name_mapping = {
    "Normal": Normal,
    "Bernoulli": Bernoulli,
    "Beta": Beta,
    "Exponential": Exponential,
    "Gamma": Gamma,
    "Laplace": Laplace,
    "Logistic": Logistic,
    "Poisson": Poisson,
    "StudentT": StudentT,
    "Uniform": Uniform,
}
class BayesianNet(nn.Module):
    def __init__(self, observed=None, device=torch.device('cpu')):
        """
        The :class:`BayesianNet` class provides a convenient way to construct
        Bayesian networks, i.e., directed graphical models.

        To start, we create a :class:`BayesianNet` instance::

            class Net(BayesianNet):
                def __init__(self):
                    # Initialize...
                def forward(self, observed):
                    # Forward propagation...

        A :class:`BayesianNet` keeps two kinds of nodes

        * deterministic nodes: they are just Tensors, usually the outputs of
          PyTorch operations.
        * stochastic nodes: they are random variables in graphical models,
          and can be constructed inside the ``BayesianNet`` class like

        ::

            self.stochastic_node('Normal', name='w', mean=0., std=1.)

        To observe any stochastic nodes in the network, pass a dictionary
        mapping of ``(name, Tensor)`` pairs when calling the model which
        derives from :class:`BayesianNet`. This will assign observed values
        to the corresponding :class:`StochasticTensor` s. For example::

            net = Net()
            net({'w': w_obs})

        will set ``w`` to be observed.

        .. note::

            The observation passed must have the same type and shape as the
            :class:`StochasticTensor`.

        .. seealso::

            For more details and examples, please refer to
            :doc:`/tutorials/concepts`.

        :param observed: A dictionary of (string, Tensor) pairs, which maps
            from names of stochastic nodes to their observed values.
        :param device: The device reported by :attr:`device` while the module
            owns no parameters. Defaults to the CPU.
        """
        super(BayesianNet, self).__init__()
        self._nodes = {}
        self._cache = {}
        # A falsy ``observed`` (None or an empty mapping) becomes a fresh dict.
        self._observed = observed if observed else {}
        self._device = device

    @property
    def nodes(self):
        """
        The dictionary of all named stochastic nodes in this
        :class:`BayesianNet`.

        :return: A dict.
        """
        return self._nodes

    @property
    def cache(self):
        """
        The dictionary of all named deterministic nodes in this
        :class:`BayesianNet`.

        :return: A dict.
        """
        return self._cache

    @property
    def device(self):
        """
        The device this module lies at.

        This is the device of the module's first parameter when it has any;
        otherwise the device given at construction or by the last call to
        :meth:`to`.

        :return: torch.device
        """
        first_param = next(self.parameters(), None)
        if first_param is None:
            return self._device
        return first_param.device

    def to(self, device):
        """
        Move the module to ``device`` and remember it as the fallback device.

        :param device: The target device.
        :return: The value returned by ``nn.Module.to``.
        """
        self._device = device
        return super().to(device)

    @property
    def observed(self):
        """
        The dictionary of all observed nodes in this :class:`BayesianNet`.

        :return: A dict.
        """
        return self._observed

    def observe(self, observed):
        """
        Assign the nodes and values to be observed in this
        :class:`BayesianNet`. Any previous observations are discarded.

        :param observed: A dictionary of (string, Tensor) pairs, which maps
            from names of stochastic nodes to their observed values.
        :return: This :class:`BayesianNet`, so calls can be chained.
        """
        # Store a shallow copy so later changes to the caller's mapping do
        # not leak into this network.
        self._observed = dict(observed.items())
        return self

    def sn(self, dist, name, n_samples=None, **kwargs):
        """
        Short cut for method
        :meth:`~zhusuan.framework.bn.BayesianNet.stochastic_node`
        """
        return self.stochastic_node(dist, name, n_samples, **kwargs)

    def snode(self, *args, **kwargs):
        """
        Short cut for method
        :meth:`~zhusuan.framework.bn.BayesianNet.stochastic_node`
        """
        return self.stochastic_node(*args, **kwargs)

    @staticmethod
    def _check_node_name(name):
        """Raise ``ValueError`` unless ``name`` is a string."""
        if not isinstance(name, str):
            raise ValueError("name of stochastic_node must be str")

    def _add_node(self, name, distribution, n_samples, extra):
        """
        Register a :class:`StochasticTensor` under ``name`` and return its
        ``tensor`` attribute.

        :param name: The key the node is stored under in :attr:`nodes`.
        :param distribution: The distribution instance the node follows.
        :param n_samples: Forwarded to :class:`StochasticTensor`.
        :param extra: Dict of extra keyword arguments forwarded to
            :class:`StochasticTensor`.
        """
        node = StochasticTensor(self, name, distribution,
                                n_samples=n_samples, **extra)
        self._nodes[name] = node
        return node.tensor

    def stochastic_node(self, distribution, name, n_samples=None, **kwargs):
        """
        Add a stochastic node in this :class:`BayesianNet` that follows the
        distribution given by the ``distribution`` parameter.

        :param distribution: Either the name of a supported distribution (a
            key of ``name_mapping``) or an instance of ``Distribution``.
        :param name: The unique name of the node.
        :param n_samples: number of samples per sample process
        :param kwargs: Parameters of the distribution which the node builds
            with. They are also forwarded to :class:`StochasticTensor`.
        :return: A instance(sample) of the node.
        :raises KeyError: If ``distribution`` is a string that is not a key
            of ``name_mapping``.
        :raises ValueError: If ``distribution`` is neither a string nor a
            ``Distribution`` instance.
        """
        if isinstance(distribution, str):
            dist = name_mapping[distribution](device=self.device, **kwargs)
        elif isinstance(distribution, Distribution):
            # A ready-made distribution is re-pointed at this net's device.
            distribution._device = self.device
            dist = distribution
        else:
            raise ValueError('distribution must be name of sub class of Distribution or an instance of Distribution')
        return self._add_node(name, dist, n_samples, kwargs)

    def _log_joint(self):
        """
        Sum ``log_prob()`` over every :class:`StochasticTensor` in
        :attr:`nodes`.

        :return: The accumulated sum, or ``0`` when there is no stochastic
            node.
        """
        total = None
        for node in self._nodes.values():
            if not isinstance(node, StochasticTensor):
                continue
            term = node.log_prob()
            # Start from the first term instead of ``0 + term`` so that no
            # exception handler is needed; an addition that fails afterwards
            # is a real error and is allowed to propagate rather than
            # silently discarding the terms accumulated so far.
            total = term if total is None else total + term
        return 0 if total is None else total

    def log_joint(self, use_cache=False):
        """
        The default log joint probability of this :class:`BayesianNet`.
        It works by summing over all the conditional log probabilities of
        stochastic nodes evaluated at their current values (samples or
        observations).

        :param use_cache: If true and a value has been computed before,
            return that stored value instead of recomputing it.
        :return: A Tensor, or ``0`` when the network has no stochastic node.
        """
        # The joint density of a directed model factorizes into a product of
        # conditionals, so its logarithm is the sum of their log
        # probabilities.
        if not use_cache or not hasattr(self, '_log_joint_cache'):
            self._log_joint_cache = self._log_joint()
        return self._log_joint_cache

    # aliases of specific distribution
    def normal(self,
               name,
               mean=0.,
               std=None,
               logstd=None,
               dtype=None,
               is_continuous=True,
               is_reparameterized=True,
               group_ndims=0,
               n_samples=None,
               **kwargs):
        """
        Add a stochastic node named ``name`` that follows a ``Normal``
        distribution.

        All parameters other than ``name`` and ``n_samples`` are forwarded to
        the ``Normal`` constructor; ``kwargs`` is additionally forwarded to
        :class:`StochasticTensor`.

        :return: A instance(sample) of the node.
        :raises ValueError: If ``name`` is not a string.
        """
        self._check_node_name(name)
        distribution = Normal(
            mean=mean,
            std=std,
            logstd=logstd,
            dtype=dtype,
            is_continuous=is_continuous,
            is_reparameterized=is_reparameterized,
            group_ndims=group_ndims,
            device=self.device,
            **kwargs
        )
        return self._add_node(name, distribution, n_samples, kwargs)

    def bernoulli(self,
                  name,
                  logits=None,
                  probs=None,
                  dtype=None,
                  is_continuous=False,
                  group_ndims=0,
                  n_samples=None,
                  **kwargs):
        """
        Add a stochastic node named ``name`` that follows a ``Bernoulli``
        distribution.

        All parameters other than ``name`` and ``n_samples`` are forwarded to
        the ``Bernoulli`` constructor; ``kwargs`` is additionally forwarded to
        :class:`StochasticTensor`.

        :return: A instance(sample) of the node.
        :raises ValueError: If ``name`` is not a string.
        """
        self._check_node_name(name)
        distribution = Bernoulli(
            logits=logits,
            probs=probs,
            dtype=dtype,
            is_continuous=is_continuous,
            group_ndims=group_ndims,
            device=self.device,
            **kwargs
        )
        return self._add_node(name, distribution, n_samples, kwargs)

    def beta(self, name,
             alpha,
             beta,
             dtype=None,
             is_continuous=True,
             group_ndims=0,
             n_samples=None,
             **kwargs):
        """
        Add a stochastic node named ``name`` that follows a ``Beta``
        distribution.

        All parameters other than ``name`` and ``n_samples`` are forwarded to
        the ``Beta`` constructor; ``kwargs`` is additionally forwarded to
        :class:`StochasticTensor`.

        :return: A instance(sample) of the node.
        :raises ValueError: If ``name`` is not a string.
        """
        self._check_node_name(name)
        distribution = Beta(
            alpha=alpha,
            beta=beta,
            dtype=dtype,
            is_continuous=is_continuous,
            group_ndims=group_ndims,
            device=self.device,
            **kwargs
        )
        return self._add_node(name, distribution, n_samples, kwargs)

    def exponential(self, name,
                    rate,
                    dtype=None,
                    is_continuous=True,
                    group_ndims=0,
                    n_samples=None,
                    **kwargs):
        """
        Add a stochastic node named ``name`` that follows an ``Exponential``
        distribution.

        All parameters other than ``name`` and ``n_samples`` are forwarded to
        the ``Exponential`` constructor; ``kwargs`` is additionally forwarded
        to :class:`StochasticTensor`.

        :return: A instance(sample) of the node.
        :raises ValueError: If ``name`` is not a string.
        """
        self._check_node_name(name)
        distribution = Exponential(
            rate=rate,
            dtype=dtype,
            is_continuous=is_continuous,
            group_ndims=group_ndims,
            device=self.device,
            **kwargs
        )
        return self._add_node(name, distribution, n_samples, kwargs)

    def gamma(self, name,
              alpha,
              beta,
              dtype=None,
              is_continuous=True,
              group_ndims=0,
              n_samples=None,
              **kwargs):
        """
        Add a stochastic node named ``name`` that follows a ``Gamma``
        distribution.

        All parameters other than ``name`` and ``n_samples`` are forwarded to
        the ``Gamma`` constructor; ``kwargs`` is additionally forwarded to
        :class:`StochasticTensor`.

        :return: A instance(sample) of the node.
        :raises ValueError: If ``name`` is not a string.
        """
        self._check_node_name(name)
        distribution = Gamma(
            alpha=alpha,
            beta=beta,
            dtype=dtype,
            is_continuous=is_continuous,
            group_ndims=group_ndims,
            device=self.device,
            **kwargs
        )
        return self._add_node(name, distribution, n_samples, kwargs)

    def laplace(self, name,
                loc,
                scale,
                dtype=None,
                is_continuous=True,
                group_ndims=0,
                n_samples=None,
                **kwargs):
        """
        Add a stochastic node named ``name`` that follows a ``Laplace``
        distribution.

        All parameters other than ``name`` and ``n_samples`` are forwarded to
        the ``Laplace`` constructor; ``kwargs`` is additionally forwarded to
        :class:`StochasticTensor`.

        :return: A instance(sample) of the node.
        :raises ValueError: If ``name`` is not a string.
        """
        self._check_node_name(name)
        distribution = Laplace(
            loc=loc,
            scale=scale,
            dtype=dtype,
            is_continuous=is_continuous,
            group_ndims=group_ndims,
            device=self.device,
            **kwargs
        )
        return self._add_node(name, distribution, n_samples, kwargs)

    def logistic(self, name,
                 loc,
                 scale,
                 dtype=None,
                 is_continuous=True,
                 group_ndims=0,
                 n_samples=None,
                 **kwargs):
        """
        Add a stochastic node named ``name`` that follows a ``Logistic``
        distribution.

        All parameters other than ``name`` and ``n_samples`` are forwarded to
        the ``Logistic`` constructor; ``kwargs`` is additionally forwarded to
        :class:`StochasticTensor`.

        :return: A instance(sample) of the node.
        :raises ValueError: If ``name`` is not a string.
        """
        self._check_node_name(name)
        # Build a Logistic distribution; this alias previously constructed a
        # Laplace one by mistake.
        distribution = Logistic(
            loc=loc,
            scale=scale,
            dtype=dtype,
            is_continuous=is_continuous,
            group_ndims=group_ndims,
            device=self.device,
            **kwargs
        )
        return self._add_node(name, distribution, n_samples, kwargs)

    def poisson(self, name,
                rate,
                dtype=None,
                is_continuous=True,
                group_ndims=0,
                n_samples=None,
                **kwargs):
        """
        Add a stochastic node named ``name`` that follows a ``Poisson``
        distribution.

        All parameters other than ``name`` and ``n_samples`` are forwarded to
        the ``Poisson`` constructor; ``kwargs`` is additionally forwarded to
        :class:`StochasticTensor`.

        :return: A instance(sample) of the node.
        :raises ValueError: If ``name`` is not a string.
        """
        # NOTE(review): ``is_continuous`` defaults to True here although the
        # Poisson distribution is discrete -- confirm this default is
        # intended before changing it, since callers may rely on it.
        self._check_node_name(name)
        distribution = Poisson(
            rate=rate,
            dtype=dtype,
            is_continuous=is_continuous,
            group_ndims=group_ndims,
            device=self.device,
            **kwargs
        )
        return self._add_node(name, distribution, n_samples, kwargs)

    def studentT(self, name,
                 df,
                 loc=0.,
                 scale=1.,
                 dtype=None,
                 is_continuous=True,
                 group_ndims=0,
                 n_samples=None,
                 **kwargs):
        """
        Add a stochastic node named ``name`` that follows a ``StudentT``
        distribution.

        All parameters other than ``name`` and ``n_samples`` are forwarded to
        the ``StudentT`` constructor; ``kwargs`` is additionally forwarded to
        :class:`StochasticTensor`.

        :return: A instance(sample) of the node.
        :raises ValueError: If ``name`` is not a string.
        """
        self._check_node_name(name)
        distribution = StudentT(
            df=df,
            loc=loc,
            scale=scale,
            dtype=dtype,
            is_continuous=is_continuous,
            group_ndims=group_ndims,
            device=self.device,
            **kwargs
        )
        return self._add_node(name, distribution, n_samples, kwargs)