From 43d37f7eda3d4d08e820fa6ece1955593f6a2256 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 19 Feb 2024 13:26:39 +0100 Subject: [PATCH] Fix bug in the `ConditionalOTDataset` --- docs/neural/data.rst | 10 +--- src/ott/neural/data/__init__.py | 2 +- src/ott/neural/data/dataloaders.py | 91 ------------------------------ src/ott/neural/data/datasets.py | 87 ++++++++++++++++++++++++++++ tests/neural/conftest.py | 70 +++++++++++------------ 5 files changed, 125 insertions(+), 135 deletions(-) delete mode 100644 src/ott/neural/data/dataloaders.py create mode 100644 src/ott/neural/data/datasets.py diff --git a/docs/neural/data.rst b/docs/neural/data.rst index 970499ff5..95f05f93f 100644 --- a/docs/neural/data.rst +++ b/docs/neural/data.rst @@ -11,11 +11,5 @@ Datasets .. autosummary:: :toctree: _autosummary - dataloaders.OTDataSet - -Dataloaders ------------ -.. autosummary:: - :toctree: _autosummary - - dataloaders.ConditionalOTDataLoader + datasets.OTDataset + datasets.ConditionalOTDataset diff --git a/src/ott/neural/data/__init__.py b/src/ott/neural/data/__init__.py index 51f8dd2af..785604b21 100644 --- a/src/ott/neural/data/__init__.py +++ b/src/ott/neural/data/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import dataloaders +from . import datasets diff --git a/src/ott/neural/data/dataloaders.py b/src/ott/neural/data/dataloaders.py deleted file mode 100644 index 8083a744c..000000000 --- a/src/ott/neural/data/dataloaders.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, List, Mapping, Optional - -import numpy as np -from jax import tree_util - -__all__ = ["OTDataSet", "ConditionalOTDataLoader"] - - -class OTDataSet: - """Data set for OT problems. - - Args: - lin: Linear part of the measure. - quad: Quadratic part of the measure. - conditions: Conditions of the source measure. - """ - - def __init__( - self, - lin: Optional[np.ndarray] = None, - quad: Optional[np.ndarray] = None, - conditions: Optional[np.ndarray] = None, - ): - if lin is not None: - if quad is not None: - assert len(lin) == len(quad) - self.n_samples = len(lin) - else: - self.n_samples = len(lin) - else: - self.n_samples = len(quad) - if conditions is not None: - assert len(conditions) == self.n_samples - - self.lin = lin - self.quad = quad - self.conditions = conditions - self._tree = {} - if lin is not None: - self._tree["lin"] = lin - if quad is not None: - self._tree["quad"] = quad - if conditions is not None: - self._tree["conditions"] = conditions - - def __getitem__(self, idx: np.ndarray) -> Mapping[str, np.ndarray]: - return tree_util.tree_map(lambda x: x[idx], self._tree) - - def __len__(self): - return self.n_samples - - -class ConditionalOTDataLoader: - """Data loader for OT problems with conditions. - - This data loader wraps several data loaders and samples from them. - - Args: - dataloaders: List of data loaders. - seed: Random seed. 
- """ - - def __init__( - self, - dataloaders: List[Any], - seed: int = 0 # dataloader should subclass torch dataloader - ): - super().__init__() - self.dataloaders = dataloaders - self.conditions = list(dataloaders) - self.rng = np.random.default_rng(seed=seed) - - def __next__(self) -> Mapping[str, np.ndarray]: - idx = self.rng.choice(len(self.conditions)) - return next(iter(self.dataloaders[idx])) - - def __iter__(self) -> "ConditionalOTDataLoader": - return self diff --git a/src/ott/neural/data/datasets.py b/src/ott/neural/data/datasets.py new file mode 100644 index 000000000..990c27a2a --- /dev/null +++ b/src/ott/neural/data/datasets.py @@ -0,0 +1,87 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Optional + +import jax.tree_util as jtu +import numpy as np + +__all__ = ["OTDataset", "ConditionalOTDataset"] + + +class OTDataset: + """Dataset for Optimal transport problems. + + Args: + lin: Linear part of the measure. + quad: Quadratic part of the measure. + conditions: Conditions of the source measure. 
+ """ + + def __init__( + self, + lin: Optional[np.ndarray] = None, + quad: Optional[np.ndarray] = None, + conditions: Optional[np.ndarray] = None, + ): + self.data = {} + if lin is not None: + self.data["lin"] = lin + if quad is not None: + self.data["quad"] = quad + if conditions is not None: + self.data["conditions"] = conditions + self._check_sizes() + + def _check_sizes(self) -> None: + sizes = {k: len(v) for k, v in self.data.items()} + if not len(set(sizes.values())) == 1: + raise ValueError(f"Not all arrays have the same size: {sizes}.") + + def __getitem__(self, idx: np.ndarray) -> Dict[str, np.ndarray]: + return jtu.tree_map(lambda x: x[idx], self.data) + + def __len__(self) -> int: + for v in self.data.values(): + return len(v) + return 0 + + +# TODO(michalk8): rename +class ConditionalOTDataset: + """Dataset for OT problems with conditions. + + This data loader wraps several data loaders and samples from them. + + Args: + datasets: Datasets to sample from. + seed: Random seed. 
+ """ + + def __init__( + self, + # TODO(michalk8): allow for dict with weights + datasets: List[OTDataset], + seed: Optional[int] = None, + ): + self.datasets = tuple(datasets) + self._rng = np.random.default_rng(seed=seed) + self._iterators = () + + def __next__(self) -> Dict[str, np.ndarray]: + idx = self._rng.choice(len(self._iterators)) + return next(self._iterators[idx]) + + def __iter__(self) -> "ConditionalOTDataset": + self._iterators = tuple(iter(ds) for ds in self.datasets) + return self diff --git a/tests/neural/conftest.py b/tests/neural/conftest.py index e40f93c16..f5c48e924 100644 --- a/tests/neural/conftest.py +++ b/tests/neural/conftest.py @@ -17,21 +17,21 @@ import numpy as np import torch -from torch.utils.data import DataLoader as Torch_loader +from torch.utils.data import DataLoader -from ott.neural.data import dataloaders +from ott.neural.data import datasets @pytest.fixture(scope="module") -def data_loaders_gaussian() -> Tuple[Torch_loader, Torch_loader]: +def data_loaders_gaussian() -> Tuple[DataLoader, DataLoader]: """Returns a data loader for a simple Gaussian mixture.""" rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - src_dataset = dataloaders.OTDataSet(lin=source) - tgt_dataset = dataloaders.OTDataSet(lin=target) - loader_src = Torch_loader(src_dataset, batch_size=16, shuffle=True) - loader_tgt = Torch_loader(tgt_dataset, batch_size=16, shuffle=True) + src_dataset = datasets.OTDataset(lin=source) + tgt_dataset = datasets.OTDataset(lin=target) + loader_src = DataLoader(src_dataset, batch_size=16, shuffle=True) + loader_tgt = DataLoader(tgt_dataset, batch_size=16, shuffle=True) return loader_src, loader_tgt @@ -44,22 +44,22 @@ def data_loader_gaussian_conditional(): source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 2)) - 2.0 - ds0 = dataloaders.OTDataSet( + ds0 = datasets.OTDataset( lin=source_0, target_lin=target_0, conditions=np.zeros_like(source_0) * 
0.0 ) - ds1 = dataloaders.OTDataSet( + ds1 = datasets.OTDataset( lin=source_1, target_lin=target_1, conditions=np.ones_like(source_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) - dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0) - dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1) + dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) + dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) - return dataloaders.ConditionalOTDataLoader((dl0, dl1)) + return datasets.ConditionalOTDataset((dl0, dl1)) @pytest.fixture(scope="module") @@ -71,13 +71,13 @@ def data_loader_gaussian_with_conditions(): source_conditions = rng.normal(size=(100, 1)) target_conditions = rng.normal(size=(100, 1)) - 1.0 - dataset = dataloaders.OTDataSet( + dataset = datasets.OTDataset( lin=source, target_lin=target, conditions=source_conditions, target_conditions=target_conditions ) - return Torch_loader(dataset, batch_size=16, shuffle=True) + return DataLoader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -86,8 +86,8 @@ def genot_data_loader_linear(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 2)) + 1.0 - dataset = dataloaders.OTDataSet(lin=source, target_lin=target) - return Torch_loader(dataset, batch_size=16, shuffle=True) + dataset = datasets.OTDataset(lin=source, target_lin=target) + return DataLoader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -98,22 +98,22 @@ def genot_data_loader_linear_conditional(): target_0 = rng.normal(size=(100, 2)) + 1.0 source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 2)) + 1.0 - ds0 = dataloaders.OTDataSet( + ds0 = datasets.OTDataset( lin=source_0, target_lin=target_0, conditions=np.zeros_like(source_0) * 0.0 ) - ds1 = dataloaders.OTDataSet( + ds1 = datasets.OTDataset( lin=source_1, target_lin=target_1, 
conditions=np.ones_like(source_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) - dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0) - dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1) + dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) + dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) - return dataloaders.ConditionalOTDataLoader((dl0, dl1)) + return datasets.ConditionalOTDataset((dl0, dl1)) @pytest.fixture(scope="module") @@ -122,8 +122,8 @@ def genot_data_loader_quad(): rng = np.random.default_rng(seed=0) source = rng.normal(size=(100, 2)) target = rng.normal(size=(100, 1)) + 1.0 - dataset = dataloaders.OTDataSet(quad=source, target_quad=target) - return Torch_loader(dataset, batch_size=16, shuffle=True) + dataset = datasets.OTDataset(quad=source, target_quad=target) + return DataLoader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -134,22 +134,22 @@ def genot_data_loader_quad_conditional(): target_0 = rng.normal(size=(100, 1)) + 1.0 source_1 = rng.normal(size=(100, 2)) target_1 = rng.normal(size=(100, 1)) + 1.0 - ds0 = dataloaders.OTDataSet( + ds0 = datasets.OTDataset( quad=source_0, target_quad=target_0, conditions=np.zeros_like(source_0) * 0.0 ) - ds1 = dataloaders.OTDataSet( + ds1 = datasets.OTDataset( quad=source_1, target_quad=target_1, conditions=np.ones_like(source_1) * 1.0 ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) - dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0) - dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1) + dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) + dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) - return dataloaders.ConditionalOTDataLoader((dl0, dl1)) + return datasets.ConditionalOTDataset((dl0, dl1)) @pytest.fixture(scope="module") @@ -160,13 +160,13 @@ def 
genot_data_loader_fused(): target_q = rng.normal(size=(100, 1)) + 1.0 source_lin = rng.normal(size=(100, 2)) target_lin = rng.normal(size=(100, 2)) + 1.0 - dataset = dataloaders.OTDataSet( + dataset = datasets.OTDataset( lin=source_lin, quad=source_q, target_lin=target_lin, target_quad=target_q ) - return Torch_loader(dataset, batch_size=16, shuffle=True) + return DataLoader(dataset, batch_size=16, shuffle=True) @pytest.fixture(scope="module") @@ -183,14 +183,14 @@ def genot_data_loader_fused_conditional(): source_lin_1 = 2 * rng.normal(size=(100, 2)) target_lin_1 = 2 * rng.normal(size=(100, 2)) + 1.0 - ds0 = dataloaders.OTDataSet( + ds0 = datasets.OTDataset( lin=source_lin_0, target_lin=target_lin_0, quad=source_q_0, target_quad=target_q_0, conditions=np.zeros_like(source_lin_0) * 0.0 ) - ds1 = dataloaders.OTDataSet( + ds1 = datasets.OTDataset( lin=source_lin_1, target_lin=target_lin_1, quad=source_q_1, @@ -199,6 +199,6 @@ def genot_data_loader_fused_conditional(): ) sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True) sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True) - dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0) - dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1) - return dataloaders.ConditionalOTDataLoader((dl0, dl1)) + dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0) + dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1) + return datasets.ConditionalOTDataset((dl0, dl1))