Skip to content

Commit

Permalink
Merge branch 'dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
nmichlo committed Mar 19, 2021
2 parents 85db681 + 758df37 commit 787dd8f
Show file tree
Hide file tree
Showing 14 changed files with 851 additions and 54 deletions.
12 changes: 10 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ The easiest way to use disent is by running `experiements/hydra_system.py` and c

### Features

Disent includes implementations of modules, metrics and datasets from various papers. However modules marked with a "🧵" are newly introduced in disent for [nmichlo](https://github.com/nmichlo)'s MSc. research!
Disent includes implementations of modules, metrics and datasets from various papers. However modules marked with a "🧵" are introduced in disent for [my](https://github.com/nmichlo) MSc. research.

#### Frameworks
- **Unsupervised**:
Expand All @@ -109,6 +109,9 @@ Disent includes implementations of modules, metrics and datasets from various pa
+ [TVAE](https://arxiv.org/abs/1802.04403)
- **Experimental**:
+ 🧵 Ada-TVAE
- Adaptive Triplet VAE
+ 🧵 DO-TVE
- Data Overlap Triplet Variational Encoder
+ *various others not worth mentioning*

Many popular disentanglement frameworks still need to be added, please
Expand All @@ -130,9 +133,14 @@ submit an issue if you have a request for an additional framework.
+ [SAP](https://arxiv.org/abs/1711.00848)
+ [Unsupervised Scores](https://github.com/google-research/disentanglement_lib)
+ 🧵 Flatness Score
- Measures max width over path length of factor traversal embeddings, a combined measure of linearity and ordering.
+ 🧵 Dual Flatness - Linearity & Ordering
- Measure **linearity** of factor traversal embeddings using average Pearson's correlation matrices
- Measure **ordering** of factor traversal embedding using average Spearman's rank correlation matrices
- Measure **ordering** of embeddings by checking anchor-positive and anchor-negative distances correspond to ground-truth factors

Some popular metrics still need to be added, please submit an issue if you wish to
add your own or you have a request for an additional metric.
add your own, or you have a request.

<details><summary><b>todo</b></summary><p>

Expand Down
41 changes: 31 additions & 10 deletions disent/data/util/state_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def pos_to_idx(self, positions) -> np.ndarray:
- positions are lists of integers, with each element < their corresponding factor size
- indices are integers < size
"""
positions = np.array(positions).T
positions = np.moveaxis(positions, source=-1, destination=0)
return np.ravel_multi_index(positions, self._factor_sizes)

def idx_to_pos(self, indices) -> np.ndarray:
Expand All @@ -92,7 +92,7 @@ def idx_to_pos(self, indices) -> np.ndarray:
- positions are lists of integers, with each element < their corresponding factor size
"""
positions = np.unravel_index(indices, self._factor_sizes)
return np.array(positions).T
return np.moveaxis(positions, source=0, destination=-1)

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Sampling Functions - any dim array, only last axis counts! #
Expand Down Expand Up @@ -157,19 +157,40 @@ def resample_factors(self, factors, fixed_factor_indices) -> np.ndarray:
"""
return self.sample_missing_factors(np.array(factors)[..., fixed_factor_indices], fixed_factor_indices)

def sample_random_traversal_factors(self, f_idx: int = None) -> np.ndarray:
def _get_f_idx_and_factors_and_size(self, f_idx: int = None, factors=None, num: int = None):
# choose a random factor if not given
if f_idx is None:
f_idx = np.random.randint(0, self.num_factors)
f_size = self.factor_sizes[f_idx]
# Aka. a traversal along a single factor
# make sequential factors, one randomly sampled list of
# factors, then repeated, with one index mutated as if set by range()
factors = self.sample_factors(size=1)
factors = factors.repeat(f_size, axis=0)
# sample factors if not given
if factors is None:
factors = self.sample_factors(size=1)
else:
factors = factors.reshape((1, self.num_factors))
# get size if not given
if num is None:
num = self.factor_sizes[f_idx]
else:
assert num > 0
# generate a traversal
factors = factors.repeat(num, axis=0)
# return everything
return f_idx, factors, num

def sample_random_traversal_factors(self, f_idx: int = None, factors=None) -> np.ndarray:
f_idx, factors, f_size = self._get_f_idx_and_factors_and_size(f_idx=f_idx, factors=factors, num=None)
# generate traversal
factors[:, f_idx] = np.arange(f_size)
# return factors
return factors

def sample_random_cycle_factors(self, f_idx: int = None, factors=None, num: int = None):
f_idx, factors, num = self._get_f_idx_and_factors_and_size(f_idx=f_idx, factors=factors, num=num)
# generate traversal
grid = np.linspace(0, self.factor_sizes[f_idx]-1, num=num, endpoint=True)
grid = np.int64(np.around(grid))
factors[:, f_idx] = grid
# return factors
return factors

# ========================================================================= #
# Hidden State Space #
Expand Down Expand Up @@ -269,4 +290,4 @@ def sample_random_traversal_factors(self, f_idx: int = None) -> np.ndarray:
# """
# get the original index of factors
# """
# return self._state_to_orig_idx[self._states.pos_to_idx(factors)]
# return self._state_to_orig_idx[self._states.pos_to_idx(factors)]
24 changes: 2 additions & 22 deletions disent/frameworks/vae/unsupervised/_dipvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from disent.frameworks.helper.util import compute_ave_loss_and_logs
from disent.frameworks.vae.unsupervised._betavae import BetaVae
from disent.util.math import torch_cov_matrix


# ========================================================================= #
Expand Down Expand Up @@ -114,7 +115,7 @@ def _dip_compute_regulariser(self, cov_matrix):
def _dip_estimate_cov_matrix(self, d_posterior: Normal):
z_mean, z_var = d_posterior.mean, d_posterior.variance
# compute covariance over batch
cov_z_mean = estimate_covariance(z_mean)
cov_z_mean = torch_cov_matrix(z_mean)
# compute covariance matrix based on mode
if self.cfg.dip_mode == "i":
cov_matrix = cov_z_mean
Expand All @@ -128,27 +129,6 @@ def _dip_estimate_cov_matrix(self, d_posterior: Normal):
return cov_matrix


# ========================================================================= #
# Helper #
# ========================================================================= #


def estimate_covariance(xs):
"""
Calculate the covariance of multivariate random variable from samples
over a batch (eg. z_mean(s) calculated from minibatch with shape (BxZ))
- Reference: https://github.com/paruby/DIP-VAE/blob/master/dip_vae.py
"""
# E[mu mu.T]
E_x_x_t = torch.mean(xs.unsqueeze(2) * xs.unsqueeze(1), dim=0)
# E[mu] (mean of distributions)
E_x = torch.mean(xs, dim=0)
# covariance matrix of model mean
cov_x = E_x_x_t - (E_x.unsqueeze(1) * E_x.unsqueeze(0))
# done!
return cov_x


# ========================================================================= #
# END #
# ========================================================================= #
27 changes: 15 additions & 12 deletions disent/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

# Nathan Michlo et. al
from ._flatness import metric_flatness
from ._dual_flatness import metric_dual_flatness


# ========================================================================= #
Expand All @@ -42,19 +43,21 @@


FAST_METRICS = {
'dci': _wrapped_partial(metric_dci, num_train=1000, num_test=500, boost_mode='sklearn'), # takes
'factor_vae': _wrapped_partial(metric_factor_vae, num_train=700, num_eval=350, num_variance_estimate=1000), # may not be accurate, but it just takes waay too long otherwise 20+ seconds
'flatness': _wrapped_partial(metric_flatness, factor_repeats=128),
'mig': _wrapped_partial(metric_mig, num_train=2000),
'sap': _wrapped_partial(metric_sap, num_train=2000, num_test=1000),
'unsupervised': _wrapped_partial(metric_unsupervised, num_train=2000),
'dci': _wrapped_partial(metric_dci, num_train=1000, num_test=500, boost_mode='sklearn'), # takes
'factor_vae': _wrapped_partial(metric_factor_vae, num_train=700, num_eval=350, num_variance_estimate=1000), # may not be accurate, but it just takes waay too long otherwise 20+ seconds
'flatness': _wrapped_partial(metric_flatness, factor_repeats=128),
'dual_flatness': _wrapped_partial(metric_dual_flatness, factor_repeats=128),
'mig': _wrapped_partial(metric_mig, num_train=2000),
'sap': _wrapped_partial(metric_sap, num_train=2000, num_test=1000),
'unsupervised': _wrapped_partial(metric_unsupervised, num_train=2000),
}

DEFAULT_METRICS = {
'dci': metric_dci,
'factor_vae': metric_factor_vae,
'flatness': metric_flatness,
'mig': metric_mig,
'sap': metric_sap,
'unsupervised': metric_unsupervised,
'dci': metric_dci,
'factor_vae': metric_factor_vae,
'flatness': metric_flatness,
'dual_flatness': metric_dual_flatness,
'mig': metric_mig,
'sap': metric_sap,
'unsupervised': metric_unsupervised,
}
Loading

0 comments on commit 787dd8f

Please sign in to comment.