diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py
index 9ad89c5..aa85cfb 100644
--- a/torch_struct/distributions.py
+++ b/torch_struct/distributions.py
@@ -10,9 +10,6 @@
 from .semirings import (
     LogSemiring,
     MaxSemiring,
-    EntropySemiring,
-    CrossEntropySemiring,
-    KLDivergenceSemiring,
     MultiSampledSemiring,
     KMaxSemiring,
     StdSemiring,
@@ -72,14 +69,25 @@ def log_prob(self, value):
 
     @lazy_property
     def entropy(self):
-        """
-        Compute entropy for distribution :math:`H[z]`.
+        r"""
+        Compute entropy for distribution :math:`H[p]`.
+
+        Algorithm derivation:
+        ..math::
+            H[p] &= E_{p(z)}[-\log p(z)]\\
+            &= -E_{p(z)}\big[ \log [\frac{1}{Z} \prod\limits_{c \in \mathcal{C}} \exp\{\phi_c(z_c)\}] \big]\\
+            &= -E_{p(z)}\big[ \sum\limits_{c \in \mathcal{C}} \phi_{c}(z_c) - \log Z \big]\\
+            &= \log Z -E_{p(z)}\big[\sum\limits_{c \in \mathcal{C}} \phi_{c}(z_c)\big]\\
+            &= \log Z - \sum\limits_{c \in \mathcal{C}} p(z_c) \phi_{c}(z_c)
 
         Returns:
             entropy (*batch_shape*)
         """
-
-        return self._struct(EntropySemiring).sum(self.log_potentials, self.lengths)
+        logZ = self.partition
+        p = self.marginals
+        phi = self.log_potentials
+        Hp = logZ - (p * phi).reshape(p.shape[0], -1).sum(-1)
+        return Hp
 
     def cross_entropy(self, other):
         """
@@ -91,10 +99,11 @@ def cross_entropy(self, other):
         Returns:
             cross entropy (*batch_shape*)
         """
-
-        return self._struct(CrossEntropySemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        logZ = other.partition
+        p = self.marginals
+        phi_q = other.log_potentials
+        Hq = logZ - (p * phi_q).reshape(p.shape[0], -1).sum(-1)
+        return Hq
 
     def kl(self, other):
         """
@@ -104,11 +113,15 @@ def kl(self, other):
         other : Comparison distribution
 
         Returns:
-            cross entropy (*batch_shape*)
+            kl divergence (*batch_shape*)
         """
-        return self._struct(KLDivergenceSemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        logZp = self.partition
+        logZq = other.partition
+        p = self.marginals
+        phi_p = self.log_potentials
+        phi_q = other.log_potentials
+        KLpq = (p * (phi_p - phi_q)).reshape(p.shape[0], -1).sum(-1) - logZp + logZq
+        return KLpq
 
     @lazy_property
     def max(self):
@@ -472,6 +485,23 @@ def __init__(self, log_potentials, lengths=None, args={}, multiroot=False):
         super(NonProjectiveDependencyCRF, self).__init__(log_potentials, lengths, args)
         self.multiroot = multiroot
 
+    def log_prob(self, value):
+        """
+        Compute log probability over values :math:`p(z)`.
+
+        Parameters:
+            value (tensor): One-hot events (*sample_shape x batch_shape x event_shape*)
+
+        Returns:
+            log_probs (*sample_shape x batch_shape*)
+        """
+        s = value.shape
+        # assumes values do not have any 1s outside of the lengths
+        value_total_log_potentials = (
+            (value * self.log_potentials.expand(s)).reshape(*s[:-2], -1).sum(-1)
+        )
+        return value_total_log_potentials - self.partition
+
     @lazy_property
     def marginals(self):
         """
@@ -502,7 +532,3 @@ def argmax(self):
         (Currently not implemented)
         """
         pass
-
-    @lazy_property
-    def entropy(self):
-        pass