Commit 2ec3b5a: Merge branch 'dev' into main
nmichlo committed May 26, 2021
2 parents 46505de + f80dca2
Showing 132 changed files with 3,561 additions and 1,555 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/python-test.yml
@@ -29,9 +29,8 @@ jobs:
      - name: Install dependencies
        run: |
          python3 -m pip install --upgrade pip
-          python3 -m pip install -r requirements-test.txt
+          python3 -m pip install -r requirements-pre.txt
          python3 -m pip install -r requirements.txt
          python3 -m pip install -r requirements-test.txt
      - name: Test with pytest
        run: |
8 changes: 4 additions & 4 deletions README.md
@@ -72,7 +72,7 @@ Please use the following citation if you use Disent in your research:

**WARNING**: Disent is still under active development. Features and APIs are not considered stable and should be expected to change! A very limited set of tests currently exists; these will be expanded upon in time.

-The easiest way to use disent is by running `experiements/hydra_system.py` and changing the root config in `experiements/config/config.yaml`. Configurations are managed with [Hydra Config](https://github.com/facebookresearch/hydra)
+The easiest way to use disent is by running `experiment/run.py` and changing the root config in `experiment/config/config.yaml`. Configurations are managed with [Hydra Config](https://github.com/facebookresearch/hydra).

**Pypi**:

@@ -88,8 +88,8 @@ The easiest way to use disent is by running `experiements/hydra_system.py` and c

3. Install the requirements for python 3.8 with `pip3 install -r requirements.txt`

-4. Run the default experiment after configuring `experiments/config/config.yaml`
-   by running `PYTHONPATH=. python3 experiments/run.py`
+4. Run the default experiment after configuring `experiment/config/config.yaml`
+   by running `PYTHONPATH=. python3 experiment/run.py`

----------------------

@@ -231,7 +231,7 @@ from torch.optim import Adam
from torch.utils.data import DataLoader
from disent.data.groundtruth import XYObjectData
from disent.dataset.groundtruth import GroundTruthDataset
-from disent.frameworks.vae.unsupervised import BetaVae
+from disent.frameworks.vae import BetaVae
from disent.metrics import metric_dci, metric_mig
from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder
from disent.schedule import CyclicSchedule
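For reference, the import changed above comes from the README's quick-start example, which is only partially visible in this diff. Below is a minimal end-to-end sketch of that quick-start under the new `disent.frameworks.vae` import path; `ToStandardisedTensor`, the `make_optimizer_fn`/`make_model_fn` arguments, and `BetaVae.cfg` are assumed from the README of this version and are not confirmed by the diff itself:

    import pytorch_lightning as pl
    from torch.optim import Adam
    from torch.utils.data import DataLoader
    from disent.data.groundtruth import XYObjectData
    from disent.dataset.groundtruth import GroundTruthDataset
    from disent.frameworks.vae import BetaVae  # new import path from this commit
    from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder
    from disent.transform import ToStandardisedTensor  # assumed helper, not shown in this diff

    # wrap the ground-truth data so a DataLoader can sample observations from it
    data = XYObjectData()
    dataset = GroundTruthDataset(data, transform=ToStandardisedTensor())
    dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

    # a beta-VAE over 64x64 RGB observations
    module = BetaVae(
        make_optimizer_fn=lambda params: Adam(params, lr=1e-3),
        make_model_fn=lambda: AutoEncoder(
            encoder=EncoderConv64(x_shape=(3, 64, 64), z_size=6, z_multiplier=2),
            decoder=DecoderConv64(x_shape=(3, 64, 64), z_size=6),
        ),
        cfg=BetaVae.cfg(beta=4),
    )

    # train for a few steps
    pl.Trainer(max_steps=256).fit(module, dataloader)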
9 changes: 6 additions & 3 deletions disent/data/groundtruth/_shapes3d.py
@@ -62,9 +62,12 @@ def __init__(self, data_dir='data/dataset/3dshapes', in_memory=False, force_down
# ========================================================================= #


-if __name__ == '__main__':
-    dataset = Shapes3dData(data_dir='data/dataset/shapes3d-1-64-64-3')
-    # pair_dataset = PairedVariationDataset(dataset, k='uniform')
+# if __name__ == '__main__':
+#     dataset = RandomDataset(Shapes3dData())
+#     dataloader = DataLoader(dataset, num_workers=os.cpu_count(), batch_size=256)
+#
+#     for batch in tqdm(dataloader):
+#         pass

# # test that dimensions are resampled correctly, and only differ by a certain number of factors, not all.
# for i in range(10):
6 changes: 4 additions & 2 deletions disent/data/groundtruth/_xysquares.py
@@ -26,7 +26,9 @@
from typing import Tuple
from disent.data.groundtruth.base import GroundTruthData
import numpy as np
-from disent.util import chunked
+
+from disent.util import iter_chunks


log = logging.getLogger(__name__)

@@ -93,7 +95,7 @@ def __getitem__(self, idx):
        offset, space, size = self._offset, self._spacing, self._square_size
        # GENERATE
        obs = np.zeros(self.observation_shape, dtype=np.uint8)
-        for i, (fx, fy) in enumerate(chunked(factors, 2)):
+        for i, (fx, fy) in enumerate(iter_chunks(factors, 2)):
            x, y = offset + space * fx, offset + space * fy
            if self._rgb:
                obs[y:y+size, x:x+size, i] = self._fill_value
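For context, the renamed helper groups a flat factor vector into fixed-size tuples, e.g. `(fx0, fy0, fx1, fy1, ...)` into `(fx, fy)` pairs in the loop above. A self-contained sketch of such a helper follows; the actual `disent.util.iter_chunks` may differ in signature and edge-case handling:

    def iter_chunks(items, chunk_size):
        # yield consecutive runs of `chunk_size` items from the input sequence
        items = list(items)
        for i in range(0, len(items), chunk_size):
            yield items[i:i + chunk_size]

    assert list(iter_chunks([3, 1, 4, 1], 2)) == [[3, 1], [4, 1]]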
52 changes: 32 additions & 20 deletions disent/data/groundtruth/base.py
@@ -97,7 +97,7 @@ def _do_download_dataset(self):
            no_data = not os.path.exists(path)
            # download data
            if self._force_download or no_data:
-                download_file(url, path)
+                download_file(url, path, overwrite_existing=True)

    @property
    def dataset_paths(self) -> List[str]:
@@ -134,7 +134,7 @@ def _do_download_and_process_dataset(self):
        do_data = self._force_download or (no_data and do_proc)

        if do_data:
-            download_file(self.dataset_url, self._data_path)
+            download_file(self.dataset_url, self._data_path, overwrite_existing=True)

        if do_proc:
            # TODO: also used in io save file, convert to with syntax.
@@ -182,8 +182,21 @@ def _preprocess_dataset(self, path_src, path_dst):

class Hdf5PreprocessedGroundTruthData(PreprocessedDownloadableGroundTruthData, metaclass=ABCMeta):
    """
-    Automatically download and pre-process an hdf5 dataset into the specific chunk sizes.
-    TODO: Only supports one dataset from the hdf5 file itself, labels etc need a custom implementation.
+    Automatically download and pre-process an hdf5 dataset
+    into the specified chunk sizes.
+
+    Often the processed (re-chunked) dataset will be optimized for random access,
+    while the unprocessed dataset will be better suited to sequential reads.
+    - The chunk size specifies the region of data to be loaded when accessing a
+      single element of the dataset; if the chunk size is not set correctly,
+      unneeded data will be loaded when accessing observations.
+    - Override `hdf5_chunk_size` to set the chunk size. For random-access
+      optimized data this should be set to the minimum observation shape that can
+      be broadcast across the shape of the dataset, e.g. with observations of shape
+      (64, 64, 3), set the chunk size to (1, 64, 64, 3).
+
+    TODO: Only supports one dataset from the hdf5 file
+          itself; labels etc. need a custom implementation.
    """

    def __init__(self, data_dir='data/dataset', in_memory=False, force_download=False, force_preprocess=False):
@@ -192,26 +205,25 @@ def __init__(self, data_dir='data/dataset', in_memory=False, force_download=Fals

        # Load the entire dataset into memory if required
        if self._in_memory:
-            # Only load the dataset once, no matter how many instances of the class are created.
-            # data is stored on the underlying class at the _DATA property.
-            # TODO: this is weird
-            if not hasattr(self.__class__, '_DATA'):
-                log.info(f'[DATASET: {self.__class__.__name__}]: Loading...')
-                # Often the (non-chunked) dataset will be optimized for random accesses,
-                # while the unprocessed (chunked) dataset will be better for sequential reads.
-                with h5py.File(self.dataset_path, 'r') as db:
-                    # indexing dataset objects returns numpy array
-                    # instantiating np.array from the dataset requires double memory.
-                    self.__class__._DATA = db[self.hdf5_name][:]
-                log.info(f'[DATASET: {self.__class__.__name__}]: Loaded!')
+            with h5py.File(self.dataset_path, 'r', libver='latest', swmr=True) as db:
+                # indexing dataset objects returns numpy array
+                # instantiating np.array from the dataset requires double memory.
+                self._memory_data = db[self.hdf5_name][:]
+        else:
+            # is this thread safe?
+            self._hdf5_file = h5py.File(self.dataset_path, 'r', libver='latest', swmr=True)
+            self._hdf5_data = self._hdf5_file[self.hdf5_name]

    def __getitem__(self, idx):
        if self._in_memory:
-            return self.__class__._DATA[idx]
+            return self._memory_data[idx]
        else:
-            # This actually doesnt seem too slow
-            with h5py.File(self.dataset_path, 'r', libver='latest', swmr=True) as f:
-                return f[self.hdf5_name][idx]
+            return self._hdf5_data[idx]
+
+    def __del__(self):
+        # do we need to do this?
+        if not self._in_memory:
+            self._hdf5_file.close()

    def _preprocess_dataset(self, path_src, path_dst):
        import os
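The new docstring above explains why the preprocessing step re-chunks the data. A minimal h5py sketch of that idea follows, with hypothetical file and dataset names (this is not disent's actual preprocessing code):

    import h5py

    # copy 'observations' into a new file, re-chunked so that reading a single
    # observation touches exactly one chunk, per the (1, 64, 64, 3) example above
    with h5py.File('raw.h5', 'r') as src, h5py.File('rechunked.h5', 'w') as dst:
        obs = src['observations']  # e.g. shape (N, 64, 64, 3)
        out = dst.create_dataset('observations', shape=obs.shape, dtype=obs.dtype, chunks=(1, 64, 64, 3))
        for i in range(0, len(obs), 4096):  # copy in slices to bound memory use
            out[i:i + 4096] = obs[i:i + 4096]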
9 changes: 7 additions & 2 deletions disent/data/util/in_out.py
@@ -23,6 +23,8 @@
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

import logging
+import warnings
+

log = logging.getLogger(__name__)

@@ -79,8 +81,11 @@ def download_file(url, save_path=None, overwrite_existing=False, chunk_size=4096
        raise Exception(f'Invalid save path: "{save_path}"')

    # check whether the save path already exists
-    if not overwrite_existing and os.path.isfile(save_path):
-        raise Exception(f'File already exists: "{save_path}" set overwrite_existing=True to overwrite.')
+    if os.path.isfile(save_path):
+        if overwrite_existing:
+            warnings.warn(f'Overwriting existing file: "{save_path}"')
+        else:
+            raise Exception(f'File already exists: "{save_path}", set overwrite_existing=True to overwrite.')

    # we download to a temporary file in case there is an error
    temp_download_path = os.path.join(path_dir, f'.{path_base}.download.temp')
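The comment above hints at the download strategy: write to a hidden temp file, then move it into place so a failed transfer never leaves a partial file at the real path. A condensed standard-library sketch of the same pattern (not the actual `download_file` implementation, which streams in chunks):

    import os
    import urllib.request

    def download_to(url, save_path):
        temp_path = save_path + '.download.temp'
        urllib.request.urlretrieve(url, temp_path)  # raises on failure, leaving only the temp file
        os.replace(temp_path, save_path)            # atomic move into place on success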
35 changes: 17 additions & 18 deletions disent/data/util/state_space.py
@@ -24,6 +24,7 @@

import numpy as np
from disent.util import LengthIter
+from disent.visualize.visualize_util import get_factor_traversal


# ========================================================================= #
@@ -157,40 +158,38 @@ def resample_factors(self, factors, fixed_factor_indices) -> np.ndarray:
        """
        return self.sample_missing_factors(np.array(factors)[..., fixed_factor_indices], fixed_factor_indices)

-    def _get_f_idx_and_factors_and_size(self, f_idx: int = None, factors=None, num: int = None):
+    def _get_f_idx_and_factors_and_size(self, f_idx: int = None, base_factors=None, num: int = None):
        """
        :param f_idx: Sampled randomly in the range [0, num_factors) if not given.
+        :param base_factors: Sampled randomly from all possible factors if not given. Coerced into the shape (1, num_factors)
        :param num: Set to the factor size `self.factor_sizes[f_idx]` if not given.
        :return: All values above in a tuple.
        """
        # choose a random factor if not given
        if f_idx is None:
            f_idx = np.random.randint(0, self.num_factors)
        # sample factors if not given
-        if factors is None:
-            factors = self.sample_factors(size=1)
+        if base_factors is None:
+            base_factors = self.sample_factors(size=1)
        else:
-            factors = factors.reshape((1, self.num_factors))
+            base_factors = np.reshape(base_factors, (1, self.num_factors))
        # get size if not given
        if num is None:
            num = self.factor_sizes[f_idx]
        else:
            assert num > 0
        # generate a traversal
-        factors = factors.repeat(num, axis=0)
+        base_factors = base_factors.repeat(num, axis=0)
        # return everything
-        return f_idx, factors, num
+        return f_idx, base_factors, num

-    def sample_random_traversal_factors(self, f_idx: int = None, factors=None) -> np.ndarray:
-        f_idx, factors, f_size = self._get_f_idx_and_factors_and_size(f_idx=f_idx, factors=factors, num=None)
+    def sample_random_factor_traversal(self, f_idx: int = None, base_factors=None, num: int = None, mode='interval') -> np.ndarray:
+        f_idx, base_factors, num = self._get_f_idx_and_factors_and_size(f_idx=f_idx, base_factors=base_factors, num=num)
        # generate traversal
-        factors[:, f_idx] = np.arange(f_size)
+        base_factors[:, f_idx] = get_factor_traversal(self.factor_sizes[f_idx], num_frames=num, mode=mode)
        # return factors
-        return factors
+        return base_factors

-    def sample_random_cycle_factors(self, f_idx: int = None, factors=None, num: int = None):
-        f_idx, factors, num = self._get_f_idx_and_factors_and_size(f_idx=f_idx, factors=factors, num=num)
-        # generate traversal
-        grid = np.linspace(0, self.factor_sizes[f_idx]-1, num=num, endpoint=True)
-        grid = np.int64(np.around(grid))
-        factors[:, f_idx] = grid
-        # return factors
-        return factors

# ========================================================================= #
# Hidden State Space #
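The deleted `sample_random_cycle_factors` logic (evenly spaced indices via linspace) appears to have been absorbed into the new `get_factor_traversal` helper. A sketch of what such a helper might compute; the real `disent.visualize.visualize_util.get_factor_traversal` and its supported modes may differ:

    import numpy as np

    def get_factor_traversal(factor_size, num_frames, mode='interval'):
        if mode == 'interval':
            # evenly spaced indices over [0, factor_size - 1], matching the
            # linspace logic removed from sample_random_cycle_factors above
            grid = np.linspace(0, factor_size - 1, num=num_frames, endpoint=True)
            return np.int64(np.around(grid))
        raise KeyError(f'invalid mode: {repr(mode)}')

    print(get_factor_traversal(8, num_frames=4))  # [0 2 5 7]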
55 changes: 40 additions & 15 deletions disent/dataset/_augment_util.py
@@ -53,49 +53,74 @@

    def _datapoint_raw_to_target(self, dat):
        x_targ = dat
-        if self.transform:
+        if self.transform is not None:
            x_targ = self.transform(x_targ)
        return x_targ

    def _datapoint_target_to_input(self, x_targ):
        x = x_targ
-        if self.augment:
+        if self.augment is not None:
            x = self.augment(x)
            x = _batch_to_observation(batch=x, obs_shape=x_targ.shape)
        return x

    def dataset_get(self, idx, mode: str):
        """
+        Gets the specified datapoint, using the specified mode.
+        - raw: direct untransformed/unaugmented observations
+        - target: transformed observations
+        - input: transformed then augmented observations
+        - pair: (input, target) tuple of observations
+
+        Pipeline:
+            1. raw    = dataset[idx]
+            2. target = transform(raw)
+            3. input  = augment(target) = augment(transform(raw))
+
+        :param idx: The index of the datapoint in the dataset
+        :param mode: {'raw', 'target', 'input', 'pair'}
+        :return: observation depending on mode
        """
+        try:
+            idx = int(idx)
+        except (TypeError, ValueError):
+            raise TypeError(f'Indices must be integer-like ({type(idx)}): {idx}')
        # we do not support indexing by lists
-        dat = self._get_augmentable_observation(idx)
+        x_raw = self._get_augmentable_observation(idx)
        # return correct data
        if mode == 'pair':
-            x_targ = self._datapoint_raw_to_target(dat)
-            x = self._datapoint_target_to_input(x_targ)
+            x_targ = self._datapoint_raw_to_target(x_raw)  # applies self.transform
+            x = self._datapoint_target_to_input(x_targ)    # applies self.augment
            return x, x_targ
        elif mode == 'input':
-            x_targ = self._datapoint_raw_to_target(dat)
-            return self._datapoint_target_to_input(x_targ)
+            x_targ = self._datapoint_raw_to_target(x_raw)  # applies self.transform
+            x = self._datapoint_target_to_input(x_targ)    # applies self.augment
+            return x
        elif mode == 'target':
-            return self._datapoint_raw_to_target(dat)
+            x_targ = self._datapoint_raw_to_target(x_raw)  # applies self.transform
+            return x_targ
        elif mode == 'raw':
-            return dat
+            return x_raw
        else:
-            raise KeyError(f'Invalid {mode=}')
+            raise ValueError(f'Invalid {mode=}')

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Multiple Datapoints #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

    def dataset_get_observation(self, *idxs):
-        xs, x_targs = zip(*[self.dataset_get(idx, mode='pair') for idx in idxs])
-        return {
-            'x': tuple(xs),
-            'x_targ': tuple(x_targs),
-        }
+        xs, xs_targ = zip(*(self.dataset_get(idx, mode='pair') for idx in idxs))
+        # handle cases
+        if self.augment is None:
+            # skipping the duplicate 'x' entry makes this 5-10% faster
+            return {
+                'x_targ': xs_targ,
+            }
+        else:
+            return {
+                'x': xs,
+                'x_targ': xs_targ,
+            }

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Batches #
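Illustrative use of the `dataset_get` mode pipeline documented above, given some hypothetical dataset instance built with a `transform` and an `augment`:

    raw = dataset.dataset_get(0, mode='raw')        # dataset[0], untouched
    target = dataset.dataset_get(0, mode='target')  # transform(raw)
    inp = dataset.dataset_get(0, mode='input')      # augment(transform(raw))
    inp, target = dataset.dataset_get(0, mode='pair')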
2 changes: 1 addition & 1 deletion disent/dataset/groundtruth/_pair_weak.py
@@ -24,7 +24,7 @@

import numpy as np
from disent.data.groundtruth.base import GroundTruthData
-from disent.dataset.groundtruth import GroundTruthDataset
+from disent.dataset.groundtruth._single import GroundTruthDataset


class GroundTruthDatasetOrigWeakPairs(GroundTruthDataset):
7 changes: 3 additions & 4 deletions disent/dataset/groundtruth/_random_dist.py
@@ -21,12 +21,11 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

import numpy as np
from torch.utils.data import DataLoader

-from disent.data.groundtruth import GroundTruthData
-from disent.data.groundtruth import XYSquaresData
-from disent.dataset.groundtruth import GroundTruthDataset
+from disent.data.groundtruth.base import GroundTruthData
+from disent.dataset.groundtruth._single import GroundTruthDataset


class GroundTruthDistDataset(GroundTruthDataset):
2 changes: 1 addition & 1 deletion disent/dataset/groundtruth/_single.py
@@ -23,7 +23,7 @@
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

import logging
-from typing import Tuple, Optional
+from typing import Tuple
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
2 changes: 1 addition & 1 deletion disent/dataset/groundtruth/_triplet.py
@@ -25,7 +25,7 @@
import logging
import numpy as np
from disent.data.groundtruth.base import GroundTruthData
-from disent.dataset.groundtruth import GroundTruthDataset
+from disent.dataset.groundtruth._single import GroundTruthDataset


log = logging.getLogger(__name__)