From 2f646f3ec4a673b78e71dfd3fc69ec899f5e9e5b Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 5 Aug 2022 23:15:31 +0200 Subject: [PATCH 1/6] compute_pairwise_loss for DfcLossModule --- .../frameworks/vae/_unsupervised__dfcvae.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/disent/frameworks/vae/_unsupervised__dfcvae.py b/disent/frameworks/vae/_unsupervised__dfcvae.py index a44cfefe..ef019d8d 100644 --- a/disent/frameworks/vae/_unsupervised__dfcvae.py +++ b/disent/frameworks/vae/_unsupervised__dfcvae.py @@ -40,6 +40,7 @@ from disent.frameworks.helper.util import compute_ave_loss from disent.frameworks.vae._unsupervised__betavae import BetaVae +from disent.nn.loss.reduction import batch_loss_reduction from disent.nn.loss.reduction import get_mean_loss_scale from disent.dataset.transform.functional import check_tensor @@ -132,6 +133,24 @@ def __init__(self, feature_layers: Optional[List[Union[str, int]]] = None, input assert input_mode in {'none', 'clamp', 'assert'} self.input_mode = input_mode + def compute_pairwise_loss(self, x_recon, x_targ, reduction='mean'): + """ + THIS DOES NOT HAVE LOSS SCALING, LIKE `compute_loss` + + x_recon and x_targ data should be an unnormalized RGB batch of + data [B x C x H x W] in the range [0, 1]. + """ + features_recon = self._extract_features(x_recon) + features_targ = self._extract_features(x_targ) + # compute losses + feature_loss = 0.0 + for (f_recon, f_targ) in zip(features_recon, features_targ): + loss = F.mse_loss(f_recon, f_targ, reduction='none') + feature_loss += batch_loss_reduction(loss, reduction=reduction) + # checks + assert (feature_loss.ndim == 1) and (len(feature_loss) == len(x_recon)) + return feature_loss + def compute_loss(self, x_recon, x_targ, reduction='mean'): """ x_recon and x_targ data should be an unnormalized RGB batch of From 950ff1a0a520414f261ea0005186caed7ed7180b Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 5 Aug 2022 23:16:51 +0200 Subject: [PATCH 2/6] minor fix --- disent/dataset/_base.py | 3 +++ disent/frameworks/vae/_unsupervised__dotvae.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/disent/dataset/_base.py b/disent/dataset/_base.py index 7ab32153..5f923f61 100644 --- a/disent/dataset/_base.py +++ b/disent/dataset/_base.py @@ -354,6 +354,9 @@ def dataset_sample_elems(self, num_samples: int, mode: str, return_indices: bool # Batches -- Ground Truth Only # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # TODO: batches should be obtained from indices + # - the wrapped gt datasets should handle generating these indices, eg. factor traversals etc. + @groundtruth_only def dataset_batch_from_factors(self, factors: np.ndarray, mode: str, collate: bool = True): """Get a batch of observations X from a batch of factors Y.""" diff --git a/disent/frameworks/vae/_unsupervised__dotvae.py b/disent/frameworks/vae/_unsupervised__dotvae.py index 582fa666..ee1b427f 100644 --- a/disent/frameworks/vae/_unsupervised__dotvae.py +++ b/disent/frameworks/vae/_unsupervised__dotvae.py @@ -33,7 +33,7 @@ from disent.frameworks.helper.reconstructions import make_reconstruction_loss from disent.frameworks.helper.reconstructions import ReconLossHandler -from disent.frameworks.vae import AdaNegTripletVae +from disent.frameworks.vae._supervised__adaneg_tvae import AdaNegTripletVae from disent.nn.loss.triplet_mining import configured_idx_mine From d9de84ba6f9e856e32df34b758e2f22424deba94 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 5 Aug 2022 23:32:16 +0200 Subject: [PATCH 3/6] fix requirements --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 813be1c0..6cc726f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ pip>=21.0 numpy>=1.19.0 torch>=1.9.0 torchvision>=0.10.0 -pytorch-lightning>=1.4.0 +pytorch-lightning>=1.4.0,<1.7 torch_optimizer>=0.1.0 scipy>=1.7.0 scikit-learn>=0.24.2 From 04c1eecdaeb085f484f20e1e147abe1df2feb975 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 5 Aug 2022 23:32:27 +0200 Subject: [PATCH 4/6] circular import test --- tests/test_000_import.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 tests/test_000_import.py diff --git a/tests/test_000_import.py b/tests/test_000_import.py new file mode 100644 index 00000000..f74b3ed8 --- /dev/null +++ b/tests/test_000_import.py @@ -0,0 +1,31 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2022 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + + +# THIS TEST FILE SHOULD ALWAYS BE LOADED AND RUN FIRST +from disent.frameworks.vae import BetaVae + + +def test_000_import(): + assert BetaVae From b700cf8d571773cf3a44ab4f32a93c209bbbd6c3 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 5 Aug 2022 23:22:36 +0200 Subject: [PATCH 5/6] fix circular imports --- disent/frameworks/{ae => }/_ae_mixin.py | 0 disent/frameworks/ae/_unsupervised__ae.py | 3 +-- disent/frameworks/vae/_unsupervised__vae.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) rename disent/frameworks/{ae => }/_ae_mixin.py (100%) diff --git a/disent/frameworks/ae/_ae_mixin.py b/disent/frameworks/_ae_mixin.py similarity index 100% rename from disent/frameworks/ae/_ae_mixin.py rename to disent/frameworks/_ae_mixin.py diff --git a/disent/frameworks/ae/_unsupervised__ae.py b/disent/frameworks/ae/_unsupervised__ae.py index f8f4a204..8a332e05 100644 --- a/disent/frameworks/ae/_unsupervised__ae.py +++ b/disent/frameworks/ae/_unsupervised__ae.py @@ -22,7 +22,6 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import logging from dataclasses import dataclass from numbers import Number from typing import Any @@ -34,7 +33,7 @@ import torch -from disent.frameworks.ae._ae_mixin import _AeAndVaeMixin +from disent.frameworks._ae_mixin import _AeAndVaeMixin from disent.frameworks.helper.util import detach_all from disent.model import AutoEncoder from disent.util.iters import map_all diff --git a/disent/frameworks/vae/_unsupervised__vae.py b/disent/frameworks/vae/_unsupervised__vae.py index 1ad85b33..d431b6d9 100644 --- a/disent/frameworks/vae/_unsupervised__vae.py +++ b/disent/frameworks/vae/_unsupervised__vae.py @@ -34,7 +34,7 @@ import torch from torch.distributions import Distribution -from disent.frameworks.ae._ae_mixin import _AeAndVaeMixin +from disent.frameworks._ae_mixin import _AeAndVaeMixin from disent.frameworks.helper.latent_distributions import LatentDistsHandler from disent.frameworks.helper.latent_distributions import make_latent_distribution from disent.frameworks.helper.util import detach_all From e57f4e26456be6e715eb9441446f8acf6268dd8e Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 5 Aug 2022 23:34:43 +0200 Subject: [PATCH 6/6] version bump v0.6.1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3968160b..1c6755d0 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ author="Nathan Juraj Michlo", author_email="NathanJMichlo@gmail.com", - version="0.6.0", + version="0.6.1", python_requires=">=3.8", # we make use of standard library features only in 3.8 packages=setuptools.find_packages(),