From 65e64b91a7fb1e2c0203cade5fa62932a3380b9e Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 15 Mar 2021 19:52:27 +0200 Subject: [PATCH 1/7] data traversal experiments --- disent/data/util/state_space.py | 35 ++- experiment/exp/data_traversal/run.py | 220 ++++++++++++++++++ .../exp/{visual_overlap => }/gadfly.mplstyle | 0 experiment/exp/visual_overlap/run.py | 2 +- 4 files changed, 249 insertions(+), 8 deletions(-) create mode 100644 experiment/exp/data_traversal/run.py rename experiment/exp/{visual_overlap => }/gadfly.mplstyle (100%) diff --git a/disent/data/util/state_space.py b/disent/data/util/state_space.py index ae4807c5..c891a99b 100644 --- a/disent/data/util/state_space.py +++ b/disent/data/util/state_space.py @@ -157,19 +157,40 @@ def resample_factors(self, factors, fixed_factor_indices) -> np.ndarray: """ return self.sample_missing_factors(np.array(factors)[..., fixed_factor_indices], fixed_factor_indices) - def sample_random_traversal_factors(self, f_idx: int = None) -> np.ndarray: + def _get_f_idx_and_factors_and_size(self, f_idx: int = None, factors=None, num: int = None): # choose a random factor if not given if f_idx is None: f_idx = np.random.randint(0, self.num_factors) - f_size = self.factor_sizes[f_idx] - # Aka. a traversal along a single factor - # make sequential factors, one randomly sampled list of - # factors, then repeated, with one index mutated as if set by range() - factors = self.sample_factors(size=1) - factors = factors.repeat(f_size, axis=0) + # sample factors if not given + if factors is None: + factors = self.sample_factors(size=1) + else: + factors = factors.reshape((1, self.num_factors)) + # get size if not given + if num is None: + num = self.factor_sizes[f_idx] + else: + assert num > 0 + # generate a traversal + factors = factors.repeat(num, axis=0) + # return everything + return f_idx, factors, num + + def sample_random_traversal_factors(self, f_idx: int = None, factors=None) -> np.ndarray: + f_idx, factors, f_size = self._get_f_idx_and_factors_and_size(f_idx=f_idx, factors=factors, num=None) + # generate traversal factors[:, f_idx] = np.arange(f_size) + # return factors return factors + def sample_random_cycle_factors(self, f_idx: int = None, factors=None, num: int = None): + f_idx, factors, num = self._get_f_idx_and_factors_and_size(f_idx=f_idx, factors=factors, num=num) + # generate traversal + grid = np.linspace(0, self.factor_sizes[f_idx]-1, num=num, endpoint=True) + grid = np.int64(np.around(grid)) + factors[:, f_idx] = grid + # return factors + return factors # ========================================================================= # # Hidden State Space # diff --git a/experiment/exp/data_traversal/run.py b/experiment/exp/data_traversal/run.py new file mode 100644 index 00000000..b54ff16a --- /dev/null +++ b/experiment/exp/data_traversal/run.py @@ -0,0 +1,220 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import itertools +import os +from typing import Union + +import imageio +import numpy as np +from matplotlib import pyplot as plt + +from disent.data.groundtruth import Cars3dData +from disent.data.groundtruth import DSpritesData +from disent.data.groundtruth import GroundTruthData +from disent.data.groundtruth import Shapes3dData +from disent.data.groundtruth import SmallNorbData +from disent.data.groundtruth import XYSquaresData +from disent.dataset.groundtruth import GroundTruthDataset +from disent.util import TempNumpySeed +from disent.visualize.visualize_util import make_image_grid + + +def make_rel_path(*path_segments, is_file=True): + assert not os.path.isabs(os.path.join(*path_segments)), 'path must be relative' + path = os.path.join(os.path.dirname(__file__), *path_segments) + folder_path = os.path.dirname(path) if is_file else path + os.makedirs(folder_path, exist_ok=True) + return path + +def make_rel_path_add_ext(*path_segments, ext='.png'): + # make path + path = make_rel_path(*path_segments, is_file=True) + if not os.path.splitext(path)[1]: + path = f'{path}{ext}' + return path + +def output_image(img, rel_path, save=True, plot=True): + if save and (rel_path is not None): + # convert image type + if img.dtype in (np.float16, np.float32, np.float64, np.float128): + assert np.all(img >= 0) and np.all(img <= 1.0) + img = np.uint8(img * 255) + elif img.dtype in (np.int16, np.int32, np.int64): + assert np.all(img >= 0) and np.all(img <= 255) + img = np.uint8(img) + assert img.dtype == np.uint8, f'unsupported image dtype: {img.dtype}' + # save image + imageio.imsave(make_rel_path_add_ext(rel_path, ext='.png'), img) + if plot: + plt.imshow(img) + plt.show() + return img + + +def convert_f_idxs(gt_data, f_idxs): + if f_idxs is None: + f_idxs = list(range(gt_data.num_factors)) + else: + f_idxs = [(gt_data.factor_names.index(i) if isinstance(i, str) else i) for i in f_idxs] + return f_idxs + + +def make_traversal_grid(gt_data: Union[GroundTruthData, GroundTruthDataset], f_idxs=None, factors=True, num=8): + # get defaults + if not isinstance(gt_data, GroundTruthDataset): + gt_data = GroundTruthDataset(gt_data) + f_idxs = convert_f_idxs(gt_data, f_idxs) + # sample factors + if isinstance(factors, bool): + factors = gt_data.sample_factors(1) if factors else None + # sample traversals + images = [] + for f_idx in f_idxs: + fs = gt_data.sample_random_cycle_factors(f_idx, factors=factors, num=num) + images.append(gt_data.dataset_batch_from_factors(fs, mode='raw') / 255.0) + images = np.stack(images) + # return grid + return images # (F, N, H, W, C) + + +def make_dataset_traversals( + gt_data, + f_idxs=None, num_cols=8, factors=True, + pad=8, bg_color=1.0, border=False, + rel_path=None, save=True, plot=False, + seed=777, +): + with TempNumpySeed(seed): + images = make_traversal_grid(gt_data, f_idxs=f_idxs, num=num_cols, factors=factors) + image = make_image_grid(images.reshape(np.prod(images.shape[:2]), *images.shape[2:]), pad=pad, bg_color=bg_color, border=border, num_cols=num_cols) + output_image(img=image, rel_path=rel_path, save=save, plot=plot) + return image, images + +def plot_dataset_traversals( + gt_data, + f_idxs=None, num_cols=8, factors=True, add_random_traversal=True, + pad=8, bg_color=1.0, border=False, + rel_path=None, save=True, plot=True, + seed=777, + plt_scale=7, offset=0.75, plt_transpose=False, +): + if not isinstance(gt_data, GroundTruthDataset): + gt_data = GroundTruthDataset(gt_data) + f_idxs = convert_f_idxs(gt_data, f_idxs) + # print factors + print(f'{gt_data.data.__class__.__name__}: loaded factors {tuple([gt_data.factor_names[i] for i in f_idxs])} of {gt_data.factor_names}') + # get traversal grid + _, images = make_dataset_traversals( + gt_data, + f_idxs=f_idxs, num_cols=num_cols, factors=factors, + pad=pad, bg_color=bg_color, border=border, + rel_path=None, save=False, plot=False, + seed=seed, + ) + # add random traversal + if add_random_traversal: + with TempNumpySeed(seed): + ran_imgs = gt_data.dataset_sample_batch(num_samples=num_cols, mode='raw') / 255 + images = np.concatenate([ran_imgs[None, ...], images]) + # transpose + if plt_transpose: + images = np.transpose(images, [1, 0, *range(2, images.ndim)]) + # add missing channel + if images.ndim == 4: + images = images[..., None].repeat(3, axis=-1) + assert images.ndim == 5 + # make figure + oW, oH = (0, offset*0.5) if plt_transpose else (offset, 0) + H, W, _, _, C = images.shape + assert C == 3 + cm = 1 / 2.54 + fig, axs = plt.subplots(H, W, figsize=(oW + cm*W*plt_scale, oH + cm*H*plt_scale)) + axs = np.array(axs) + # plot images + for y, x in itertools.product(range(H), range(W)): + img, ax = images[y, x], axs[y, x] + ax.imshow(img) + i, j = (y, x) if plt_transpose else (x, y) + if (i == H-1) if plt_transpose else (i == 0): + label = 'random' if (add_random_traversal and (j == 0)) else gt_data.factor_names[f_idxs[j-int(add_random_traversal)]] + (ax.set_xlabel if plt_transpose else ax.set_ylabel)(label, fontsize=26) + # ax.set_axis_off() + ax.get_xaxis().set_ticks([]) + ax.get_yaxis().set_ticks([]) + ax.get_xaxis().set_ticklabels([]) + ax.get_yaxis().set_ticklabels([]) + plt.tight_layout() + # save and show + if save and (rel_path is not None): + plt.savefig(make_rel_path_add_ext(rel_path, ext='.png')) + if plot: + plt.show() + + +if __name__ == '__main__': + + # matplotlib style + plt.style.use(os.path.join(os.path.dirname(__file__), '../gadfly.mplstyle')) + + # options + all_squares = True + add_random_traversal = True + num_cols = 7 + + # save image + for i in ([1, 2, 3, 4, 5, 6, 7, 8] if all_squares else [1, 8]): + plot_dataset_traversals( + XYSquaresData(grid_spacing=i, max_placements=8, no_warnings=True), + factors=None, + rel_path=f'plots/xy-squares-traversal-spacing{i}', + f_idxs=None, seed=7, add_random_traversal=add_random_traversal, num_cols=num_cols + ) + + plot_dataset_traversals( + Shapes3dData(), + factors=None, + rel_path=f'plots/shapes3d-traversal', + f_idxs=None, seed=47, add_random_traversal=add_random_traversal, num_cols=num_cols + ) + + plot_dataset_traversals( + DSpritesData(), + factors=None, + rel_path=f'plots/dsprites-traversal', + f_idxs=None, seed=47, add_random_traversal=add_random_traversal, num_cols=num_cols + ) + + plot_dataset_traversals( + SmallNorbData(), + factors=None, + rel_path=f'plots/smallnorb-traversal', + f_idxs=None, seed=47, add_random_traversal=add_random_traversal, num_cols=num_cols + ) + + plot_dataset_traversals( + Cars3dData(), + factors=None, + rel_path=f'plots/cars3d-traversal', + f_idxs=None, seed=47, add_random_traversal=add_random_traversal, num_cols=num_cols + ) diff --git a/experiment/exp/visual_overlap/gadfly.mplstyle b/experiment/exp/gadfly.mplstyle similarity index 100% rename from experiment/exp/visual_overlap/gadfly.mplstyle rename to experiment/exp/gadfly.mplstyle diff --git a/experiment/exp/visual_overlap/run.py b/experiment/exp/visual_overlap/run.py index c95f75e7..07496e24 100644 --- a/experiment/exp/visual_overlap/run.py +++ b/experiment/exp/visual_overlap/run.py @@ -329,7 +329,7 @@ def plot_unique_count(dfs, save_name: str = None, show_plt: bool = True, fig_l_p if __name__ == '__main__': # matplotlib style - plt.style.use(os.path.join(os.path.dirname(__file__), 'gadfly.mplstyle')) + plt.style.use(os.path.join(os.path.dirname(__file__), '../gadfly.mplstyle')) # common settings SHARED_SETTINGS = dict( From 66b1acb80d072e1f82c8492afcb645278ef58c7d Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 19 Mar 2021 01:14:43 +0200 Subject: [PATCH 2/7] Math Helper: Covariance, Pearson's, Spearman's, Generalised Mean --- disent/frameworks/vae/unsupervised/_dipvae.py | 24 +-- disent/metrics/_flatness.py | 4 +- disent/util/math.py | 189 ++++++++++++++++++ tests/test_math.py | 74 +++++++ 4 files changed, 267 insertions(+), 24 deletions(-) create mode 100644 disent/util/math.py create mode 100644 tests/test_math.py diff --git a/disent/frameworks/vae/unsupervised/_dipvae.py b/disent/frameworks/vae/unsupervised/_dipvae.py index d0a33245..aaa5d29b 100644 --- a/disent/frameworks/vae/unsupervised/_dipvae.py +++ b/disent/frameworks/vae/unsupervised/_dipvae.py @@ -30,6 +30,7 @@ from disent.frameworks.helper.util import compute_ave_loss_and_logs from disent.frameworks.vae.unsupervised._betavae import BetaVae +from disent.util.math import torch_cov_matrix # ========================================================================= # @@ -114,7 +115,7 @@ def _dip_compute_regulariser(self, cov_matrix): def _dip_estimate_cov_matrix(self, d_posterior: Normal): z_mean, z_var = d_posterior.mean, d_posterior.variance # compute covariance over batch - cov_z_mean = estimate_covariance(z_mean) + cov_z_mean = torch_cov_matrix(z_mean) # compute covariance matrix based on mode if self.cfg.dip_mode == "i": cov_matrix = cov_z_mean @@ -128,27 +129,6 @@ def _dip_estimate_cov_matrix(self, d_posterior: Normal): return cov_matrix -# ========================================================================= # -# Helper # -# ========================================================================= # - - -def estimate_covariance(xs): - """ - Calculate the covariance of multivariate random variable from samples - over a batch (eg. z_mean(s) calculated from minibatch with shape (BxZ)) - - Reference: https://github.com/paruby/DIP-VAE/blob/master/dip_vae.py - """ - # E[mu mu.T] - E_x_x_t = torch.mean(xs.unsqueeze(2) * xs.unsqueeze(1), dim=0) - # E[mu] (mean of distributions) - E_x = torch.mean(xs, dim=0) - # covariance matrix of model mean - cov_x = E_x_x_t - (E_x.unsqueeze(1) * E_x.unsqueeze(0)) - # done! - return cov_x - - # ========================================================================= # # END # # ========================================================================= # diff --git a/disent/metrics/_flatness.py b/disent/metrics/_flatness.py index 514af440..e2a48aad 100644 --- a/disent/metrics/_flatness.py +++ b/disent/metrics/_flatness.py @@ -32,7 +32,6 @@ from typing import Iterable from typing import Union -import numpy as np import torch from torch.utils.data.dataloader import default_collate @@ -201,7 +200,7 @@ def aggregate_measure_distances_along_factor( # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # return { p: { - 'ave_width': measures['widths'].mean(dim=0), # shape: (repeats,) -> () + 'ave_width': measures['widths'].mean(dim=0), # shape: (repeats,) -> () 'ave_delta': measures['deltas'].mean(dim=[0, 1]), # shape: (repeats, factor_size - 1) -> () 'ave_angle': measures['angles'].mean(dim=0), # shape: (repeats,) -> () } for p, measures in default_collate(p_measures).items() @@ -220,6 +219,7 @@ def encode_all_along_factor(ground_truth_dataset, representation_function, f_idx sequential_zs = encode_all_factors(ground_truth_dataset, representation_function, factors=factors, batch_size=batch_size) return sequential_zs + def encode_all_factors(ground_truth_dataset, representation_function, factors, batch_size: int) -> torch.Tensor: zs = [] with torch.no_grad(): diff --git a/disent/util/math.py b/disent/util/math.py new file mode 100644 index 00000000..049c04b0 --- /dev/null +++ b/disent/util/math.py @@ -0,0 +1,189 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +from typing import List +from typing import Optional +from typing import Union + +import logging +import numpy as np +import torch + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# pytorch math correlation functions # +# ========================================================================= # + + +def torch_cov_matrix(xs: torch.Tensor): + """ + Calculate the covariance matrix of multiple samples (N) of random vectors of size (X) + https://en.wikipedia.org/wiki/Covariance_matrix + - The input shape is: (N, X) + - The output shape is: (X, X) + + This should be the same as: + np.cov(xs, rowvar=False, ddof=0) + """ + # NOTE: + # torch.mm is strict matrix multiplication + # however if we multiply arrays with broadcasting: + # size(3, 1) * size(1, 2) -> size(3, 2) # broadcast, not matmul + # size(1, 3) * size(2, 1) -> size(2, 3) # broadcast, not matmul + # CHECK: + assert xs.ndim == 2 # (N, X) + Rxx = torch.mean(xs[:, :, None] * xs[:, None, :], dim=0) # (X, X) + ux = torch.mean(xs, dim=0) # (X,) + Kxx = Rxx - (ux[:, None] * ux[None, :]) # (X, X) + return Kxx + + +def torch_corr_matrix(xs: torch.Tensor): + """ + Calculate the pearson's correlation matrix of multiple samples (N) of random vectors of size (X) + https://en.wikipedia.org/wiki/Pearson_correlation_coefficient + https://en.wikipedia.org/wiki/Covariance_matrix + - The input shape is: (N, X) + - The output shape is: (X, X) + + This should be the same as: + np.corrcoef(xs, rowvar=False, ddof=0) + """ + Kxx = torch_cov_matrix(xs) + diag_Kxx = torch.rsqrt(torch.diagonal(Kxx)) + corr = Kxx * (diag_Kxx[:, None] * diag_Kxx[None, :]) + return corr + + +def torch_rank_corr_matrix(xs: torch.Tensor): + """ + Calculate the spearman's rank correlation matrix of multiple samples (N) of random vectors of size (X) + https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient + - The input shape is: (N, X) + - The output shape is: (X, X) + + Pearson's correlation measures linear relationships + Spearman's correlation measures monotonic relationships (whether linear or not) + - defined in terms of the pearson's correlation matrix of the rank variables + + TODO: check, be careful of repeated values, this might not give the correct result? + """ + rs = torch.argsort(xs, dim=0, descending=False) + return torch_corr_matrix(rs.to(xs.dtype)) + + +# aliases +torch_pearsons_corr_matrix = torch_corr_matrix +torch_spearmans_corr_matrix = torch_rank_corr_matrix + + +# ========================================================================= # +# pytorch math helper functions # +# ========================================================================= # + + +def torch_tril_mean(mat: torch.Tensor, diagonal=-1): + """ + compute the mean of the lower triangular matrix. + """ + # checks + N, M = mat.shape + assert N == M + assert diagonal == -1 + # compute + n = (N*(N-1))/2 + mean = torch.tril(mat, diagonal=diagonal).sum() / n + # done + return mean + + +# ========================================================================= # +# pytorch mean functions # +# ========================================================================= # + + +_DimTypeHint = Optional[Union[int, List[int]]] + +_POS_INF = float('inf') +_NEG_INF = float('-inf') + +_GENERALIZED_MEAN_MAP = { + 'maximum': _POS_INF, + 'quadratic': 2, + 'arithmetic': 1, + 'geometric': 0, + 'harmonic': -1, + 'minimum': _NEG_INF, +} + + +def torch_mean_generalized(xs: torch.Tensor, dim: _DimTypeHint = None, p: Union[int, str] = 1): + """ + Generalised Mean + - is this implementation actually correct? + """ + if isinstance(p, str): + p = _GENERALIZED_MEAN_MAP[p] + # warn if the type is wrong + if xs.dtype != torch.float64: + log.warning(f'Input tensor to generalised mean might not have the required precision, type is {xs.dtype} not {torch.float64}.') + # compute the specific extreme cases + if p == _POS_INF: + return torch.max(xs, dim=dim).values + elif p == _NEG_INF: + return torch.min(xs, dim=dim).values + # compute the number of elements being averaged + if dim is None: + dim = list(range(xs.ndim)) + n = torch.prod(torch.as_tensor(xs.shape)[dim]) + # compute the specific cases + if p == 0: + # geometric mean + # orig numerically unstable: torch.prod(xs, dim=dim) ** (1 / n) + return torch.exp((1 / n) * torch.sum(torch.log(xs), dim=dim)) + elif p == 1: + # arithmetic mean + return torch.mean(xs, dim=dim) + else: + # generalised mean + return ((1/n) * torch.sum(xs ** p, dim=dim)) ** (1/p) + + +def torch_mean_quadratic(xs, dim: _DimTypeHint = None): + return torch_mean_generalized(xs, dim=dim, p='quadratic') + + +def torch_mean_geometric(xs, dim: _DimTypeHint = None): + return torch_mean_generalized(xs, dim=dim, p='geometric') + + +def torch_mean_harmonic(xs, dim: _DimTypeHint = None): + return torch_mean_generalized(xs, dim=dim, p='harmonic') + + +# ========================================================================= # +# end # +# ========================================================================= # diff --git a/tests/test_math.py b/tests/test_math.py new file mode 100644 index 00000000..61ece977 --- /dev/null +++ b/tests/test_math.py @@ -0,0 +1,74 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import numpy as np +import torch +from scipy.stats import gmean +from scipy.stats import hmean + +from disent.util import to_numpy +from disent.util.math import torch_corr_matrix +from disent.util.math import torch_cov_matrix +from disent.util.math import torch_mean_generalized + + +def test_cov_corr(): + + for i in range(5, 1000, 250): + for j in range(2, 100, 25): + + # these match when torch.float64 is used, not when torch float32 is used... + xs = torch.randn(i, j, dtype=torch.float64) + + np_cov = torch.from_numpy(np.cov(to_numpy(xs), rowvar=False, ddof=0)).to(xs.dtype) + np_cor = torch.from_numpy(np.corrcoef(to_numpy(xs), rowvar=False, ddof=0)).to(xs.dtype) + + cov = torch_cov_matrix(xs) + cor = torch_corr_matrix(xs) + + assert torch.allclose(np_cov, cov) + assert torch.allclose(np_cor, cor) + + +def test_generalised_mean(): + xs = torch.abs(torch.randn(2, 1000, 3, dtype=torch.float64)) + + # normal + assert torch.allclose(torch_mean_generalized(xs, p='arithmetic', dim=1), torch.mean(xs, dim=1)) + assert torch.allclose(torch_mean_generalized(xs, p=1, dim=1), torch.mean(xs, dim=1)) + + # scipy equivalents + assert torch.allclose(torch_mean_generalized(xs, p='geometric', dim=1), torch.as_tensor(gmean(xs, axis=1))) + assert torch.allclose(torch_mean_generalized(xs, p='harmonic', dim=1), torch.as_tensor(hmean(xs, axis=1))) + assert torch.allclose(torch_mean_generalized(xs, p=0, dim=1), torch.as_tensor(gmean(xs, axis=1))) + assert torch.allclose(torch_mean_generalized(xs, p=-1, dim=1), torch.as_tensor(hmean(xs, axis=1))) + assert torch.allclose(torch_mean_generalized(xs, p=0), torch.as_tensor(gmean(xs, axis=None))) # scipy default axis is 0 + assert torch.allclose(torch_mean_generalized(xs, p=-1), torch.as_tensor(hmean(xs, axis=None))) # scipy default axis is 0 + + # min max + assert torch.allclose(torch_mean_generalized(xs, p='maximum', dim=1), torch.max(xs, dim=1).values) + assert torch.allclose(torch_mean_generalized(xs, p='minimum', dim=1), torch.min(xs, dim=1).values) + assert torch.allclose(torch_mean_generalized(xs, p=np.inf, dim=1), torch.max(xs, dim=1).values) + assert torch.allclose(torch_mean_generalized(xs, p=-np.inf, dim=1), torch.min(xs, dim=1).values) + From 115040ab10fe29ca8a6a3eec22d41975e1e2d276 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 19 Mar 2021 01:14:58 +0200 Subject: [PATCH 3/7] state space fixes? --- disent/data/util/state_space.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/disent/data/util/state_space.py b/disent/data/util/state_space.py index c891a99b..4765abe0 100644 --- a/disent/data/util/state_space.py +++ b/disent/data/util/state_space.py @@ -82,7 +82,7 @@ def pos_to_idx(self, positions) -> np.ndarray: - positions are lists of integers, with each element < their corresponding factor size - indices are integers < size """ - positions = np.array(positions).T + positions = np.moveaxis(positions, source=-1, destination=0) return np.ravel_multi_index(positions, self._factor_sizes) def idx_to_pos(self, indices) -> np.ndarray: @@ -92,7 +92,7 @@ def idx_to_pos(self, indices) -> np.ndarray: - positions are lists of integers, with each element < their corresponding factor size """ positions = np.unravel_index(indices, self._factor_sizes) - return np.array(positions).T + return np.moveaxis(positions, source=0, destination=-1) # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Sampling Functions - any dim array, only last axis counts! # @@ -290,4 +290,4 @@ def sample_random_cycle_factors(self, f_idx: int = None, factors=None, num: int # """ # get the original index of factors # """ -# return self._state_to_orig_idx[self._states.pos_to_idx(factors)] \ No newline at end of file +# return self._state_to_orig_idx[self._states.pos_to_idx(factors)] From 3bdf5ac6c4a76f89be459bbe1f90bbdba06a689f Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 19 Mar 2021 01:25:33 +0200 Subject: [PATCH 4/7] tests for previous commit --- tests/test_state_space.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_state_space.py b/tests/test_state_space.py index bedcfad1..c3dc3d89 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -37,7 +37,7 @@ [2, 3, 4, 5], [2, 3, 4], [1, 2, 3], - [1, 100, 1], + [1, 33, 1], [1, 1, 1], [1], ] @@ -59,10 +59,13 @@ def test_discrete_state_space_one_to_one(): for factor_sizes in FACTOR_SIZES: states = StateSpace(factor_sizes=factor_sizes) # check that entire range of values is generated + k = np.random.randint(1, 5) # chances of this failing are extremely low, but it could happen... - pos_0 = states.sample_factors(100_000) - assert np.all(pos_0.max(axis=0) == (states.factor_sizes - 1)) - assert np.all(pos_0.min(axis=0) == 0) + pos_0 = states.sample_factors([int(100_000 ** (1/k))] * k) + # check random values are in the right ranges + all_dims = tuple(range(pos_0.ndim)) + assert np.all(np.max(pos_0, axis=all_dims[:-1]) == (states.factor_sizes - 1)) + assert np.all(np.min(pos_0, axis=all_dims[:-1]) == 0) # check that converting between them keeps values the same idx_0 = states.pos_to_idx(pos_0) pos_1 = states.idx_to_pos(idx_0) From f61d51266ecf5370fc288bcd54a0a9dc51b6d180 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 19 Mar 2021 02:33:00 +0200 Subject: [PATCH 5/7] dual flatness metric -- correlation + rank correlation + swap ratio --- disent/metrics/_dual_flatness.py | 273 +++++++++++++++++++++++++++++++ disent/util/math.py | 26 ++- 2 files changed, 297 insertions(+), 2 deletions(-) create mode 100644 disent/metrics/_dual_flatness.py diff --git a/disent/metrics/_dual_flatness.py b/disent/metrics/_dual_flatness.py new file mode 100644 index 00000000..130ec92b --- /dev/null +++ b/disent/metrics/_dual_flatness.py @@ -0,0 +1,273 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + + +import logging + +import numpy as np +import torch +from torch.utils.data.dataloader import default_collate + +from disent.dataset.groundtruth import GroundTruthDataset +from disent.metrics._flatness import encode_all_along_factor +from disent.metrics._flatness import encode_all_factors +from disent.metrics._flatness import get_device +from disent.util import iter_chunks +from disent.util import to_numpy +from disent.util.math import torch_corr_matrix +from disent.util.math import torch_mean_generalized +from disent.util.math import torch_nan_to_num +from disent.util.math import torch_rank_corr_matrix +from disent.util.math import torch_tril_mean + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# flatness # +# ========================================================================= # + + +def metric_dual_flatness( + ground_truth_dataset: GroundTruthDataset, + representation_function: callable, + factor_repeats: int = 1024, + batch_size: int = 64, +): + """ + Computes the dual flatness metrics (ordering & linearity): + swap_ratio: percent of correctly ordered ground truth factors in the latent space + ave_corr: average of the correlation matrix (Pearson's) for latent traversals + ave_rank_corr: average of the rank correlation matrix (Spearman's) for latent traversals + + Args: + ground_truth_dataset: GroundTruthData to be sampled from. + representation_function: Function that takes observations as input and outputs a dim_representation sized representation for each observation. + factor_repeats: how many times to repeat a traversal along each factors, these are then averaged together. + batch_size: Batch size to process at any time while generating representations, should not effect metric results. + Returns: + Dictionary with metrics + """ + fs_measures = aggregate_measure_distances_along_all_factors(ground_truth_dataset, representation_function, repeats=factor_repeats, batch_size=batch_size) + # get the means + def multi_mean(name, ps=('arithmetic', 'geometric', 'harmonic')): + return { + f'dual_flatness.{name}.{p}': to_numpy(torch_mean_generalized(fs_measures[name].to(torch.float64), dim=0, p=p).to(torch.float32)) + for p in ps + } + results = { + **multi_mean('ave_corr', ps=('arithmetic', 'geometric')), + **multi_mean('ave_rank_corr', ps=('arithmetic', 'geometric')), + # traversals + # **multi_mean('swap_ratio_l1', ps=('arithmetic', 'geometric')), + **multi_mean('swap_ratio_l2', ps=('arithmetic', 'geometric')), + # any pairs + # **multi_mean('ran_swap_ratio_l1', ps=('arithmetic',)), + **multi_mean('ran_swap_ratio_l2', ps=('arithmetic',)), + } + # convert values from torch + return {k: float(v) for k, v in results.items()} + + +def aggregate_measure_distances_along_all_factors( + ground_truth_dataset, + representation_function, + repeats: int, + batch_size: int, +) -> dict: + # COMPUTE AGGREGATES FOR EACH FACTOR + # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # + fs_measures = default_collate([ + aggregate_measure_distances_along_factor(ground_truth_dataset, representation_function, f_idx=f_idx, repeats=repeats, batch_size=batch_size) + for f_idx in range(ground_truth_dataset.num_factors) + ]) + + # COMPUTE RANDOM + # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # + values = [] + num_samples = int(np.mean(ground_truth_dataset.factor_sizes) * repeats) + for idxs in iter_chunks(range(num_samples), batch_size): + # encode factors + factors = ground_truth_dataset.sample_factors(size=len(idxs)) + zs = encode_all_factors(ground_truth_dataset, representation_function, factors, batch_size=batch_size) + # get random triplets from factors + rai, rpi, rni = np.random.randint(0, len(factors), size=(3, len(factors) * 4)) + rai, rpi, rni = reorder_by_factor_dist(factors, rai, rpi, rni) + # check differences + swap_ratio_l1, swap_ratio_l2 = compute_swap_ratios(zs[rai], zs[rpi], zs[rni]) + values.append({'ran_swap_ratio_l1': swap_ratio_l1, 'ran_swap_ratio_l2': swap_ratio_l2}) + # return all + return { + **fs_measures, + **default_collate(values), + } + + +def reorder_by_factor_dist(factors, rai, rpi, rni): + a_fs, p_fs, n_fs = factors[rai], factors[rpi], factors[rni] + # sort all + d_ap = np.linalg.norm(a_fs - p_fs, ord=1, axis=-1) + d_an = np.linalg.norm(a_fs - n_fs, ord=1, axis=-1) + # swap + swap_mask = d_ap <= d_an + rpi_NEW = np.where(swap_mask, rpi, rni) + rni_NEW = np.where(swap_mask, rni, rpi) + # return new + return rai, rpi_NEW, rni_NEW + + +def compute_swap_ratios(a_zs, p_zs, n_zs): + ap_delta_l1, an_delta_l1 = torch.norm(a_zs - p_zs, dim=-1, p=1), torch.norm(a_zs - n_zs, dim=-1, p=1) + ap_delta_l2, an_delta_l2 = torch.norm(a_zs - p_zs, dim=-1, p=2), torch.norm(a_zs - n_zs, dim=-1, p=2) + swap_ratio_l1 = (ap_delta_l1 <= an_delta_l1).to(torch.float32).mean() + swap_ratio_l2 = (ap_delta_l2 <= an_delta_l2).to(torch.float32).mean() + return swap_ratio_l1, swap_ratio_l2 + + +def aggregate_measure_distances_along_factor( + ground_truth_dataset, + representation_function, + f_idx: int, + repeats: int, + batch_size: int, +) -> dict: + # FEED FORWARD, COMPUTE ALL + # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # + measures = [] + for i in range(repeats): + # generate repeated factors, varying one factor over the entire range + zs_traversal = encode_all_along_factor(ground_truth_dataset, representation_function, f_idx=f_idx, batch_size=batch_size) + # check the number of swapped elements along a factor + swap_ratio_l1, swap_ratio_l2 = compute_swap_ratios(zs_traversal[:-2], zs_traversal[1:-1], zs_traversal[2:]) + # correlations -- replace invalid values + corr_matrix = torch.abs(torch_nan_to_num(torch_corr_matrix(zs_traversal), nan=1.0, posinf=1.0, neginf=-1.0)) + rank_corr_matrix = torch.abs(torch_nan_to_num(torch_rank_corr_matrix(zs_traversal), nan=1.0, posinf=1.0, neginf=-1.0)) + # save variables + measures.append({ + 'swap_ratio_l1': swap_ratio_l1, + 'swap_ratio_l2': swap_ratio_l2, + 'ave_corr': torch_tril_mean(corr_matrix), + 'ave_rank_corr': torch_tril_mean(rank_corr_matrix), + }) + + # AGGREGATE DATA - For each distance measure + # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # + measures = default_collate(measures) + return { + 'swap_ratio_l1': measures['swap_ratio_l1'].mean(dim=0), # shape: (repeats,) -> () + 'swap_ratio_l2': measures['swap_ratio_l2'].mean(dim=0), # shape: (repeats,) -> () + 'ave_corr': measures['ave_corr'].mean(dim=0), # shape: (repeats,) -> () + 'ave_rank_corr': measures['ave_rank_corr'].mean(dim=0), # shape: (repeats,) -> () + } + + +# ========================================================================= # +# END # +# ========================================================================= # + + +if __name__ == '__main__': + import pytorch_lightning as pl + from torch.optim import Adam + from torch.utils.data import DataLoader + from disent.data.groundtruth import XYObjectData, XYSquaresData + from disent.dataset.groundtruth import GroundTruthDataset, GroundTruthDatasetPairs + from disent.frameworks.vae.unsupervised import BetaVae + from disent.frameworks.vae.weaklysupervised import AdaVae + from disent.frameworks.vae.supervised import TripletVae + from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder + from disent.transform import ToStandardisedTensor + from disent.util import colors + from disent.util import Timer + + def get_str(r): + return ', '.join(f'{k}={v:6.4f}' for k, v in r.items()) + + def print_r(name, steps, result, clr=colors.lYLW, t: Timer = None): + print(f'{clr}{name:<13} ({steps:>04}){f" {colors.GRY}[{t.pretty}]{clr}" if t else ""}: {get_str(result)}{colors.RST}') + + def calculate(name, steps, dataset, get_repr): + global aggregate_measure_distances_along_factor + with Timer() as t: + r = metric_dual_flatness(dataset, get_repr, factor_repeats=64, batch_size=64) + results.append((name, steps, r)) + print_r(name, steps, r, colors.lRED, t=t) + print(colors.GRY, '='*100, colors.RST, sep='') + return r + + class XYOverlapData(XYSquaresData): + def __init__(self, square_size=8, grid_size=64, grid_spacing=None, num_squares=3, rgb=True): + if grid_spacing is None: + grid_spacing = (square_size+1) // 2 + super().__init__(square_size=square_size, grid_size=grid_size, grid_spacing=grid_spacing, num_squares=num_squares, rgb=rgb) + + # datasets = [XYObjectData(rgb=False, palette='white'), XYSquaresData(), XYOverlapData(), XYObjectData()] + datasets = [XYObjectData()] + + results = [] + for data in datasets: + + dataset = GroundTruthDatasetPairs(data, transform=ToStandardisedTensor()) + dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, pin_memory=True) + module = AdaVae( + make_optimizer_fn=lambda params: Adam(params, lr=5e-4), + make_model_fn=lambda: AutoEncoder( + encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), + decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), + ), + cfg=AdaVae.cfg(beta=0.001, loss_reduction='mean') + ) + + # dataset = GroundTruthDatasetTriples(data, transform=ToStandardisedTensor(), swap_metric='manhattan') + # dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, pin_memory=True) + # module = TripletVae( + # make_optimizer_fn=lambda params: Adam(params, lr=5e-4), + # make_model_fn=lambda: AutoEncoder( + # encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), + # decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), + # ), + # cfg=TripletVae.cfg(beta=0.001, loss_reduction='mean', triplet_p=2, triplet_scale=100) + # ) + + # we cannot guarantee which device the representation is on + get_repr = lambda x: module.encode(x.to(module.device)) + # PHASE 1, UNTRAINED + pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=True, gpus=1, weights_summary=None).fit(module, dataloader) + module = module.to('cuda') + calculate(data.__class__.__name__, 0, dataset, get_repr) + # PHASE 2, LITTLE TRAINING + pl.Trainer(logger=False, checkpoint_callback=False, max_steps=256, gpus=1, weights_summary=None).fit(module, dataloader) + calculate(data.__class__.__name__, 256, dataset, get_repr) + # PHASE 3, MORE TRAINING + pl.Trainer(logger=False, checkpoint_callback=False, max_steps=2048, gpus=1, weights_summary=None).fit(module, dataloader) + calculate(data.__class__.__name__, 256+2048, dataset, get_repr) + results.append(None) + + for result in results: + if result is None: + print() + continue + (name, steps, result) = result + print_r(name, steps, result, colors.lYLW) diff --git a/disent/util/math.py b/disent/util/math.py index 049c04b0..a6499171 100644 --- a/disent/util/math.py +++ b/disent/util/math.py @@ -21,6 +21,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import warnings from typing import List from typing import Optional from typing import Union @@ -142,8 +143,13 @@ def torch_tril_mean(mat: torch.Tensor, diagonal=-1): def torch_mean_generalized(xs: torch.Tensor, dim: _DimTypeHint = None, p: Union[int, str] = 1): """ - Generalised Mean - - is this implementation actually correct? + Compute the generalised mean. + - p is the power + + harmonic mean ≤ geometric mean ≤ arithmetic mean + - If values have the same units: Use the arithmetic mean. + - If values have differing units: Use the geometric mean. + - If values are rates: Use the harmonic mean. """ if isinstance(p, str): p = _GENERALIZED_MEAN_MAP[p] @@ -184,6 +190,22 @@ def torch_mean_harmonic(xs, dim: _DimTypeHint = None): return torch_mean_generalized(xs, dim=dim, p='harmonic') +# ========================================================================= # +# polyfill - in later versions of pytorch # +# ========================================================================= # + + +def torch_nan_to_num(input, nan=0.0, posinf=None, neginf=None): + output = input.clone() + if nan is not None: + output[torch.isnan(input)] = nan + if posinf is not None: + output[input == np.inf] = posinf + if neginf is not None: + output[input == -np.inf] = neginf + return output + + # ========================================================================= # # end # # ========================================================================= # From eea85dc8befd48ca44238bd019fec2758299b0d5 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 19 Mar 2021 02:43:52 +0200 Subject: [PATCH 6/7] dual_flatness metric in configs --- disent/metrics/__init__.py | 27 +++++++++++++++------------ disent/metrics/_dual_flatness.py | 3 +++ experiment/config/metrics/all.yaml | 1 + 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/disent/metrics/__init__.py b/disent/metrics/__init__.py index 5f99d644..24b75836 100644 --- a/disent/metrics/__init__.py +++ b/disent/metrics/__init__.py @@ -31,6 +31,7 @@ # Nathan Michlo et. al from ._flatness import metric_flatness +from ._dual_flatness import metric_dual_flatness # ========================================================================= # @@ -42,19 +43,21 @@ FAST_METRICS = { - 'dci': _wrapped_partial(metric_dci, num_train=1000, num_test=500, boost_mode='sklearn'), # takes - 'factor_vae': _wrapped_partial(metric_factor_vae, num_train=700, num_eval=350, num_variance_estimate=1000), # may not be accurate, but it just takes waay too long otherwise 20+ seconds - 'flatness': _wrapped_partial(metric_flatness, factor_repeats=128), - 'mig': _wrapped_partial(metric_mig, num_train=2000), - 'sap': _wrapped_partial(metric_sap, num_train=2000, num_test=1000), - 'unsupervised': _wrapped_partial(metric_unsupervised, num_train=2000), + 'dci': _wrapped_partial(metric_dci, num_train=1000, num_test=500, boost_mode='sklearn'), # takes + 'factor_vae': _wrapped_partial(metric_factor_vae, num_train=700, num_eval=350, num_variance_estimate=1000), # may not be accurate, but it just takes waay too long otherwise 20+ seconds + 'flatness': _wrapped_partial(metric_flatness, factor_repeats=128), + 'dual_flatness': _wrapped_partial(metric_dual_flatness, factor_repeats=128), + 'mig': _wrapped_partial(metric_mig, num_train=2000), + 'sap': _wrapped_partial(metric_sap, num_train=2000, num_test=1000), + 'unsupervised': _wrapped_partial(metric_unsupervised, num_train=2000), } DEFAULT_METRICS = { - 'dci': metric_dci, - 'factor_vae': metric_factor_vae, - 'flatness': metric_flatness, - 'mig': metric_mig, - 'sap': metric_sap, - 'unsupervised': metric_unsupervised, + 'dci': metric_dci, + 'factor_vae': metric_factor_vae, + 'flatness': metric_flatness, + 'dual_flatness': metric_dual_flatness, + 'mig': metric_mig, + 'sap': metric_sap, + 'unsupervised': metric_unsupervised, } diff --git a/disent/metrics/_dual_flatness.py b/disent/metrics/_dual_flatness.py index 130ec92b..5a11001f 100644 --- a/disent/metrics/_dual_flatness.py +++ b/disent/metrics/_dual_flatness.py @@ -226,6 +226,9 @@ def __init__(self, square_size=8, grid_size=64, grid_spacing=None, num_squares=3 # datasets = [XYObjectData(rgb=False, palette='white'), XYSquaresData(), XYOverlapData(), XYObjectData()] datasets = [XYObjectData()] + # TODO: fix for dead dimensions + # datasets = [XYObjectData(rgb=False, palette='white')] + results = [] for data in datasets: diff --git a/experiment/config/metrics/all.yaml b/experiment/config/metrics/all.yaml index 0a333a79..ed5bb01f 100644 --- a/experiment/config/metrics/all.yaml +++ b/experiment/config/metrics/all.yaml @@ -1,6 +1,7 @@ # @package _group_ metric_list: - flatness: + - dual_flatness: - mig: - sap: - unsupervised: From 758df37f7e66c11c0e362b48e303584a7813cac1 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 19 Mar 2021 12:37:44 +0200 Subject: [PATCH 7/7] version bump v0.0.1.dev7 --- README.md | 12 ++++++++++-- setup.py | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c6cad2ac..4855229b 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,7 @@ The easiest way to use disent is by running `experiements/hydra_system.py` and c ### Features -Disent includes implementations of modules, metrics and datasets from various papers. However modules marked with a "🧵" are newly introduced in disent for [nmichlo](https://github.com/nmichlo)'s MSc. research! +Disent includes implementations of modules, metrics and datasets from various papers. However modules marked with a "🧵" are introduced in disent for [my](https://github.com/nmichlo) MSc. research. #### Frameworks - **Unsupervised**: @@ -109,6 +109,9 @@ Disent includes implementations of modules, metrics and datasets from various pa + [TVAE](https://arxiv.org/abs/1802.04403) - **Experimental**: + 🧵 Ada-TVAE + - Adaptive Triplet VAE + + 🧵 DO-TVE + - Data Overlap Triplet Variational Encoder + *various others not worth mentioning* Many popular disentanglement frameworks still need to be added, please @@ -130,9 +133,14 @@ submit an issue if you have a request for an additional framework. + [SAP](https://arxiv.org/abs/1711.00848) + [Unsupervised Scores](https://github.com/google-research/disentanglement_lib) + 🧵 Flatness Score + - Measures max width over path length of factor traversal embeddings, a combined measure of linearity and ordering. + + 🧵 Dual Flatness - Linearity & Ordering + - Measure **linearity** of factor traversal embeddings using average Pearson's correlation matrices + - Measure **ordering** of factor traversal embedding using average Spearman's rank correlation matrices + - Measure **ordering** of embeddings by checking anchor-positive and anchor-negative distances correspond to ground-truth factors Some popular metrics still need to be added, please submit an issue if you wish to -add your own or you have a request for an additional metric. +add your own, or you have a request.
todo

diff --git a/setup.py b/setup.py index 8dbf830e..7f95d834 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ author="Nathan Juraj Michlo", author_email="NathanJMichlo@gmail.com", - version="0.0.1.dev6", + version="0.0.1.dev7", python_requires="==3.8", packages=setuptools.find_packages(),