Commit
Merge branch 'dev' into main
nmichlo committed Feb 27, 2021
2 parents 8029934 + 8e98839 commit bc3e889
Showing 25 changed files with 510 additions and 334 deletions.
6 changes: 5 additions & 1 deletion disent/frameworks/ae/unsupervised/_ae.py
@@ -47,7 +47,11 @@ class AE(BaseFramework):
@dataclass
class cfg(BaseFramework.cfg):
recon_loss: str = 'mse'
loss_reduction: str = 'batch_mean'
# multiple reduction modes exist for the various loss components.
# - 'sum': sum over the entire batch
# - 'mean': mean over the entire batch
# - 'mean_sum': sum each observation, returning the mean sum over the batch
loss_reduction: str = 'mean'

def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None):
super().__init__(make_optimizer_fn, batch_augment=batch_augment, cfg=cfg)
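As a side note on the new loss_reduction options above, here is a minimal sketch (not part of the commit) of how the 'sum', 'mean' and 'mean_sum' modes relate on a toy batch tensor; the shapes and values are illustrative assumptions only.

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)       # pretend batch: B=3 observations of 4 elements

loss_sum      = x.sum()                                        # 'sum':      66.0
loss_mean     = x.mean()                                       # 'mean':     5.5
loss_mean_sum = x.reshape(x.shape[0], -1).sum(dim=-1).mean()   # 'mean_sum': 22.0

# 'mean_sum' equals 'mean' scaled by the number of elements per observation
assert torch.isclose(loss_mean_sum, loss_mean * 4)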
6 changes: 3 additions & 3 deletions disent/frameworks/helper/latent_distributions.py
@@ -105,7 +105,7 @@ def params_to_distributions_and_sample(self, z_params: Params) -> Tuple[Tuple[Di
def compute_kl_loss(
cls,
posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None,
mode: str = 'direct', reduction='batch_mean'
mode: str = 'direct', reduction='mean'
):
"""
Compute the kl divergence
@@ -164,7 +164,7 @@ def params_to_distributions(self, z_params: Params) -> Tuple[Normal, Normal]:
return posterior, prior

@staticmethod
def LEGACY_compute_kl_loss(mu, logvar, mode: str = 'direct', reduction='batch_mean'):
def LEGACY_compute_kl_loss(mu, logvar, mode: str = 'direct', reduction='mean_sum'):
"""
Calculates the KL divergence between a normal distribution with
diagonal covariance and a unit normal distribution.
@@ -174,7 +174,7 @@ def LEGACY_compute_kl_loss(mu, logvar, mode: str = 'direct', reduction='batch_me
https://github.com/google-research/disentanglement_lib (compute_gaussian_kl)
"""
assert mode == 'direct', f'legacy reference implementation of KL loss only supports mode="direct", not {repr(mode)}'
assert reduction == 'batch_mean', f'legacy reference implementation of KL loss only supports reduction="batch_mean", not {repr(reduction)}'
assert reduction == 'mean_sum', f'legacy reference implementation of KL loss only supports reduction="mean_sum", not {repr(reduction)}'
# Calculate KL divergence
kl_values = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
# Sum KL divergence across latent vector for each sample
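For reference, a minimal sketch (not part of the commit) checking the closed-form KL used by LEGACY_compute_kl_loss against torch.distributions, and applying the 'mean_sum' reduction that the legacy path now asserts; the shapes below are illustrative assumptions.

import torch
from torch.distributions import Normal, kl_divergence

mu, logvar = torch.randn(8, 6), torch.randn(8, 6)               # batch of 8, z_size of 6
posterior  = Normal(mu, (0.5 * logvar).exp())
prior      = Normal(torch.zeros_like(mu), torch.ones_like(logvar))

kl_legacy = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())      # element-wise closed form
kl_torch  = kl_divergence(posterior, prior)
assert torch.allclose(kl_legacy, kl_torch, atol=1e-5)

loss = kl_legacy.sum(dim=-1).mean()                             # 'mean_sum': sum over latents, mean over batch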
86 changes: 64 additions & 22 deletions disent/frameworks/helper/reconstructions.py
@@ -22,10 +22,12 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

import warnings
from typing import final

import torch
import torch.nn.functional as F

from disent.frameworks.helper.reductions import loss_reduction


@@ -44,18 +44,22 @@ def activate(self, x):
raise NotImplementedError

@final
def training_compute_loss(self, x_partial_recon: torch.Tensor, x_targ: torch.Tensor, reduction: str = 'batch_mean') -> torch.Tensor:
def training_compute_loss(self, x_partial_recon: torch.Tensor, x_targ: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
"""
Takes in an **unactivated** tensor from the model
as well as an original target from the dataset.
:return: The computed mean loss
:return: The computed reduced loss
"""
assert x_partial_recon.shape == x_targ.shape
batch_loss = self._compute_batch_loss(x_partial_recon, x_targ)
batch_loss = self._compute_unreduced_loss(x_partial_recon, x_targ)
loss = loss_reduction(batch_loss, reduction=reduction)
return loss

def _compute_batch_loss(self, x_partial_recon: torch.Tensor, x_targ: torch.Tensor) -> torch.Tensor:
def _compute_unreduced_loss(self, x_partial_recon: torch.Tensor, x_targ: torch.Tensor) -> torch.Tensor:
"""
Compute the loss without applying a reduction.
- loss tensor should be the same shapes as the input tensors
"""
raise NotImplementedError


@@ -80,13 +86,9 @@ def activate(self, x):
# - sigmoid is numerically not suitable with MSE
return 0.5 * (x + 1)

def _compute_batch_loss(self, x_partial_recon, x_targ):
def _compute_unreduced_loss(self, x_partial_recon, x_targ):
return F.mse_loss(self.activate(x_partial_recon), x_targ, reduction='none')

@staticmethod
def LEGACY_training_compute_loss(x_recon, x_target, reduction: str = 'batch_mean'):
raise NotImplementedError('LEGACY mse version does not exist!')


class ReconstructionLossBce(ReconstructionLoss):
"""
@@ -99,22 +101,49 @@ def activate(self, x):
# it to the range [0, 1] here to match the targets.
return torch.sigmoid(x)

def _compute_batch_loss(self, x_partial_recon, x_targ):
return F.binary_cross_entropy_with_logits(x_partial_recon, x_targ, reduction='none')

@staticmethod
def LEGACY_training_compute_loss(x_recon, x_target, reduction: str = 'batch_mean'):
def _compute_unreduced_loss(self, x_partial_recon, x_targ):
"""
Computes the Bernoulli loss for the sigmoid activation function
FROM: https://github.com/google-research/disentanglement_lib/blob/76f41e39cdeff8517f7fba9d57b09f35703efca9/disentanglement_lib/methods/shared/losses.py
REFERENCE:
https://github.com/google-research/disentanglement_lib/blob/76f41e39cdeff8517f7fba9d57b09f35703efca9/disentanglement_lib/methods/shared/losses.py
- the same when reduction=='mean_sum' for super().training_compute_loss()
REFERENCE ALT:
https://github.com/YannDubs/disentangling-vae/blob/master/disvae/models/losses.py
"""
assert reduction == 'batch_mean', f'legacy reference implementation of BCE loss only supports reduction="batch_mean", not {repr(reduction)}'
# x, x_recon = x.view(x.shape[0], -1), x_recon.view(x.shape[0], -1)
# per_sample_loss = F.binary_cross_entropy_with_logits(x_recon, x, reduction='none').sum(axis=1) # tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=x_recon, labels=x), axis=1)
# reconstruction_loss = per_sample_loss.mean() # tf.reduce_mean(per_sample_loss)
# ALTERNATIVE IMPLEMENTATION https://github.com/YannDubs/disentangling-vae/blob/master/disvae/models/losses.py
assert x_recon.shape == x_target.shape
return F.binary_cross_entropy_with_logits(x_recon, x_target, reduction="sum") / len(x_target)
return F.binary_cross_entropy_with_logits(x_partial_recon, x_targ, reduction='none')


# ========================================================================= #
# Reconstruction Distributions #
# ========================================================================= #


class ReconstructionLossBernoulli(ReconstructionLossBce):
def _compute_unreduced_loss(self, x_partial_recon, x_targ):
# This is exactly the same as the BCE version, but more 'correct'.
return -torch.distributions.Bernoulli(logits=x_partial_recon).log_prob(x_targ)
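The claim above that the Bernoulli version is exactly the BCE version can be checked with a minimal sketch (not part of the commit); the shapes and binary targets below are illustrative assumptions.

import torch
import torch.nn.functional as F

logits  = torch.randn(4, 3, 8, 8)                       # unactivated reconstruction
targets = torch.randint(0, 2, (4, 3, 8, 8)).float()     # binary targets in {0, 1}

bce       = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
bernoulli = -torch.distributions.Bernoulli(logits=logits).log_prob(targets)
assert torch.allclose(bce, bernoulli, atol=1e-6)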


class ReconstructionLossContinuousBernoulli(ReconstructionLossBce):
"""
The continuous Bernoulli: fixing a pervasive error in variational autoencoders
- Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.
- https://arxiv.org/abs/1907.06845
"""
def _compute_unreduced_loss(self, x_partial_recon, x_targ):
warnings.warn('Using continuous bernoulli distribution for reconstruction loss. This is not recommended!')
# I think there is something wrong with this...
# weird values...
return -torch.distributions.ContinuousBernoulli(logits=x_partial_recon, lims=(0.49, 0.51)).log_prob(x_targ)


class ReconstructionLossNormal(ReconstructionLossMse):
def _compute_unreduced_loss(self, x_partial_recon, x_targ):
warnings.warn('Using normal distribution for reconstruction loss. This is not recommended!')
# this is almost the same as MSE, but scaled with a tiny offset
# A value for scale should actually be passed...
return -torch.distributions.Normal(self.activate(x_partial_recon), 1.0).log_prob(x_targ)
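A minimal sketch (not part of the commit) of why the normal-distribution loss is "almost the same as MSE": with unit scale the negative log-prob is half the squared error plus a constant; shapes and values are illustrative assumptions.

import math
import torch
import torch.nn.functional as F

mu, x = torch.rand(4, 3, 8, 8), torch.rand(4, 3, 8, 8)

nll = -torch.distributions.Normal(mu, 1.0).log_prob(x)
mse = F.mse_loss(mu, x, reduction='none')
assert torch.allclose(nll, 0.5 * mse + 0.5 * math.log(2 * math.pi), atol=1e-5)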


# ========================================================================= #
@@ -124,9 +153,22 @@ def LEGACY_training_compute_loss(x_recon, x_target, reduction: str = 'batch_mean

def make_reconstruction_loss(name) -> ReconstructionLoss:
if name == 'mse':
# from the normal distribution
# real values in the range [0, 1]
return ReconstructionLossMse()
elif name == 'bce':
# from the bernoulli distribution
return ReconstructionLossBce()
elif name == 'bernoulli':
# reduces to bce
# binary values only in the set {0, 1}
return ReconstructionLossBernoulli()
elif name == 'continuous_bernoulli':
# bernoulli with a computed offset to handle values in the range [0, 1]
return ReconstructionLossContinuousBernoulli()
elif name == 'normal':
# handle all real values
return ReconstructionLossNormal()
else:
raise KeyError(f'Invalid vae reconstruction loss: {name}')

10 changes: 5 additions & 5 deletions disent/frameworks/helper/reductions.py
@@ -39,19 +39,19 @@ def loss_reduction_mean(x: torch.Tensor) -> torch.Tensor:
return x.mean()


def loss_reduction_batch_mean(x: torch.Tensor) -> torch.Tensor:
def loss_reduction_mean_sum(x: torch.Tensor) -> torch.Tensor:
return x.reshape(x.shape[0], -1).sum(dim=-1).mean()


_LOSS_REDUCTION_STRATEGIES = {
'none': lambda tensor: tensor,
# 'none': lambda tensor: tensor,
'sum': loss_reduction_sum,
'mean': loss_reduction_mean,
'batch_mean': loss_reduction_batch_mean,
'mean_sum': loss_reduction_mean_sum,
}


def loss_reduction(tensor: torch.Tensor, reduction='batch_mean'):
def loss_reduction(tensor: torch.Tensor, reduction='mean'):
return _LOSS_REDUCTION_STRATEGIES[reduction](tensor)


@@ -65,7 +65,7 @@ def get_mean_loss_scale(x: torch.Tensor, reduction: str):
assert 2 <= x.ndim <= 4, 'unsupported number of dims, must be one of: BxC, BxHxW, BxCxHxW'

# get the loss scaling
if reduction == 'batch_mean':
if reduction == 'mean_sum':
return np.prod(x.shape[1:]) # MEAN(B, SUM(C x H x W))
elif reduction == 'mean':
return 1
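To illustrate the renamed 'mean_sum' reduction and the scale returned by get_mean_loss_scale, a minimal sketch (not part of the commit); the tensor shape is an illustrative assumption.

import numpy as np
import torch

x = torch.rand(16, 3, 64, 64)                                   # B x C x H x W

mean     = x.mean()
mean_sum = x.reshape(x.shape[0], -1).sum(dim=-1).mean()
scale    = np.prod(x.shape[1:])                                 # 3 * 64 * 64 = 12288

# 'mean_sum' is just 'mean' multiplied by the per-observation element count
assert torch.allclose(mean_sum, mean * scale)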
20 changes: 19 additions & 1 deletion disent/frameworks/vae/unsupervised/_betavae.py
@@ -35,7 +35,25 @@ class BetaVae(Vae):

@dataclass
class cfg(Vae.cfg):
beta: float = 4
# BETA SCALING:
# =============
# when switching between loss reduction modes we need to scale beta to
# preserve the ratio between the loss components.
# -- for loss_reduction='mean' we usually have:
# loss = mean_recon_loss + beta * mean_kl_loss
# -- for loss_reduction='mean_sum' we usually have:
# loss = (H*W*C) * mean_recon_loss + beta * (z_size) * mean_kl_loss
# So when switching from one mode to the other, we need to scale beta to preserve these loss ratios.
# -- 'mean_sum' to 'mean':
# beta <- beta * (z_size) / (H*W*C)
# -- 'mean' to 'mean_sum':
# beta <- beta * (H*W*C) / (z_size)
# For example, converting a 'mean_sum' beta to an equivalent 'mean' beta:
# -- given values: beta=4 for 'mean_sum', with (H*W*C)=(64*64*3) and z_size=9
# beta = beta * ((z_size) / (H*W*C))
# ~= 4 * 0.0007324
#            ~= 0.003
beta: float = 0.003 # approximately equal to mean_sum beta of 4

def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None):
super().__init__(make_optimizer_fn, make_model_fn, batch_augment=batch_augment, cfg=cfg)
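Following the beta-scaling note in the new cfg above, a minimal sketch (not part of the commit) of the conversion using the stated example values ((H*W*C)=(64*64*3), z_size=9):

beta_mean_sum  = 4.0
z_size, pixels = 9, 64 * 64 * 3

beta_mean = beta_mean_sum * z_size / pixels    # 4 * 9 / 12288
print(f'{beta_mean:.5f}')                      # 0.00293  (~0.003, the new default)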
3 changes: 1 addition & 2 deletions disent/frameworks/vae/unsupervised/_dfcvae.py
@@ -145,12 +145,11 @@ def num(self):
def __call__(self, x_recon, x_targ):
return self.compute_loss(x_recon, x_targ)

def compute_loss(self, x_recon, x_targ, reduction='batch_mean'):
def compute_loss(self, x_recon, x_targ, reduction='mean'):
"""
x_recon and x_targ data should be an unnormalized RGB batch of
data [B x C x H x W] in the range [0, 1].
"""

features_recon = self._extract_features(x_recon)
features_targ = self._extract_features(x_targ)
# compute losses
