
Commit

Fix bug in the ConditionalOTDataset
michalk8 committed Feb 19, 2024
1 parent 7919051 commit 43d37f7
Showing 5 changed files with 125 additions and 135 deletions.
10 changes: 2 additions & 8 deletions docs/neural/data.rst
@@ -11,11 +11,5 @@ Datasets
.. autosummary::
:toctree: _autosummary

dataloaders.OTDataSet

Dataloaders
-----------
.. autosummary::
:toctree: _autosummary

dataloaders.ConditionalOTDataLoader
datasets.OTDataset
datasets.ConditionalOTDataset
2 changes: 1 addition & 1 deletion src/ott/neural/data/__init__.py
@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import dataloaders
from . import datasets
91 changes: 0 additions & 91 deletions src/ott/neural/data/dataloaders.py

This file was deleted.

87 changes: 87 additions & 0 deletions src/ott/neural/data/datasets.py
@@ -0,0 +1,87 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional

import jax.tree_util as jtu
import numpy as np

__all__ = ["OTDataset", "ConditionalOTDataset"]


class OTDataset:
"""Dataset for Optimal transport problems.
Args:
lin: Linear part of the measure.
quad: Quadratic part of the measure.
conditions: Conditions of the source measure.
"""

def __init__(
self,
lin: Optional[np.ndarray] = None,
quad: Optional[np.ndarray] = None,
conditions: Optional[np.ndarray] = None,
):
self.data = {}
if lin is not None:
self.data["lin"] = lin
if quad is not None:
self.data["quad"] = quad
if conditions is not None:
self.data["conditions"] = conditions
self._check_sizes()

def _check_sizes(self) -> None:
sizes = {k: len(v) for k, v in self.data.items()}
    if len(set(sizes.values())) != 1:
raise ValueError(f"Not all arrays have the same size: {sizes}.")

def __getitem__(self, idx: np.ndarray) -> Dict[str, np.ndarray]:
    return jtu.tree_map(lambda x: x[idx], self.data)

def __len__(self) -> int:
for v in self.data.values():
return len(v)
return 0


# TODO(michalk8): rename
class ConditionalOTDataset:
"""Dataset for OT problems with conditions.
This data loader wraps several data loaders and samples from them.
Args:
datasets: Datasets to sample from.
seed: Random seed.
"""

def __init__(
self,
# TODO(michalk8): allow for dict with weights
datasets: List[OTDataset],
seed: Optional[int] = None,
):
self.datasets = tuple(datasets)
self._rng = np.random.default_rng(seed=seed)
self._iterators = ()

def __next__(self) -> Dict[str, np.ndarray]:
idx = self._rng.choice(len(self._iterators))
return next(self._iterators[idx])

def __iter__(self) -> "ConditionalOTDataset":
self._iterators = tuple(iter(ds) for ds in self.datasets)
return self
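
For context, a minimal usage sketch of the two classes added above. It uses only the constructor arguments shown in this diff (lin, quad, conditions); the array shapes and the seed are illustrative, not part of the commit.

import numpy as np

from ott.neural.data import datasets

rng = np.random.default_rng(0)

# One dataset per condition value; the arrays are placeholders.
ds0 = datasets.OTDataset(
    lin=rng.normal(size=(100, 2)), conditions=np.zeros((100, 2))
)
ds1 = datasets.OTDataset(
    lin=rng.normal(size=(100, 2)) + 1.0, conditions=np.ones((100, 2))
)

# Indexing returns a dict with one entry per stored array.
batch = ds0[np.arange(16)]
assert set(batch) == {"lin", "conditions"}

# ConditionalOTDataset picks one of the wrapped datasets uniformly at random
# on every __next__ call; __iter__ must be called before iterating.
mixed = iter(datasets.ConditionalOTDataset([ds0, ds1], seed=0))
sample = next(mixed)  # a dict holding one row of either ds0 or ds1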
70 changes: 35 additions & 35 deletions tests/neural/conftest.py
@@ -17,21 +17,21 @@

import numpy as np
import torch
from torch.utils.data import DataLoader as Torch_loader
from torch.utils.data import DataLoader

from ott.neural.data import dataloaders
from ott.neural.data import datasets


@pytest.fixture(scope="module")
def data_loaders_gaussian() -> Tuple[Torch_loader, Torch_loader]:
def data_loaders_gaussian() -> Tuple[DataLoader, DataLoader]:
"""Returns a data loader for a simple Gaussian mixture."""
rng = np.random.default_rng(seed=0)
source = rng.normal(size=(100, 2))
target = rng.normal(size=(100, 2)) + 1.0
src_dataset = dataloaders.OTDataSet(lin=source)
tgt_dataset = dataloaders.OTDataSet(lin=target)
loader_src = Torch_loader(src_dataset, batch_size=16, shuffle=True)
loader_tgt = Torch_loader(tgt_dataset, batch_size=16, shuffle=True)
src_dataset = datasets.OTDataset(lin=source)
tgt_dataset = datasets.OTDataset(lin=target)
loader_src = DataLoader(src_dataset, batch_size=16, shuffle=True)
loader_tgt = DataLoader(tgt_dataset, batch_size=16, shuffle=True)
return loader_src, loader_tgt


@@ -44,22 +44,22 @@ def data_loader_gaussian_conditional():

source_1 = rng.normal(size=(100, 2))
target_1 = rng.normal(size=(100, 2)) - 2.0
ds0 = dataloaders.OTDataSet(
ds0 = datasets.OTDataset(
lin=source_0,
target_lin=target_0,
conditions=np.zeros_like(source_0) * 0.0
)
ds1 = dataloaders.OTDataSet(
ds1 = datasets.OTDataset(
lin=source_1,
target_lin=target_1,
conditions=np.ones_like(source_1) * 1.0
)
sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True)
sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True)
dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0)
dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1)
dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0)
dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1)

return dataloaders.ConditionalOTDataLoader((dl0, dl1))
return datasets.ConditionalOTDataset((dl0, dl1))


@pytest.fixture(scope="module")
@@ -71,13 +71,13 @@ def data_loader_gaussian_with_conditions():
source_conditions = rng.normal(size=(100, 1))
target_conditions = rng.normal(size=(100, 1)) - 1.0

dataset = dataloaders.OTDataSet(
dataset = datasets.OTDataset(
lin=source,
target_lin=target,
conditions=source_conditions,
target_conditions=target_conditions
)
return Torch_loader(dataset, batch_size=16, shuffle=True)
return DataLoader(dataset, batch_size=16, shuffle=True)


@pytest.fixture(scope="module")
@@ -86,8 +86,8 @@ def genot_data_loader_linear():
rng = np.random.default_rng(seed=0)
source = rng.normal(size=(100, 2))
target = rng.normal(size=(100, 2)) + 1.0
dataset = dataloaders.OTDataSet(lin=source, target_lin=target)
return Torch_loader(dataset, batch_size=16, shuffle=True)
dataset = datasets.OTDataset(lin=source, target_lin=target)
return DataLoader(dataset, batch_size=16, shuffle=True)


@pytest.fixture(scope="module")
@@ -98,22 +98,22 @@ def genot_data_loader_linear_conditional():
target_0 = rng.normal(size=(100, 2)) + 1.0
source_1 = rng.normal(size=(100, 2))
target_1 = rng.normal(size=(100, 2)) + 1.0
ds0 = dataloaders.OTDataSet(
ds0 = datasets.OTDataset(
lin=source_0,
target_lin=target_0,
conditions=np.zeros_like(source_0) * 0.0
)
ds1 = dataloaders.OTDataSet(
ds1 = datasets.OTDataset(
lin=source_1,
target_lin=target_1,
conditions=np.ones_like(source_1) * 1.0
)
sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True)
sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True)
dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0)
dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1)
dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0)
dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1)

return dataloaders.ConditionalOTDataLoader((dl0, dl1))
return datasets.ConditionalOTDataset((dl0, dl1))


@pytest.fixture(scope="module")
@@ -122,8 +122,8 @@ def genot_data_loader_quad():
rng = np.random.default_rng(seed=0)
source = rng.normal(size=(100, 2))
target = rng.normal(size=(100, 1)) + 1.0
dataset = dataloaders.OTDataSet(quad=source, target_quad=target)
return Torch_loader(dataset, batch_size=16, shuffle=True)
dataset = datasets.OTDataset(quad=source, target_quad=target)
return DataLoader(dataset, batch_size=16, shuffle=True)


@pytest.fixture(scope="module")
@@ -134,22 +134,22 @@ def genot_data_loader_quad_conditional():
target_0 = rng.normal(size=(100, 1)) + 1.0
source_1 = rng.normal(size=(100, 2))
target_1 = rng.normal(size=(100, 1)) + 1.0
ds0 = dataloaders.OTDataSet(
ds0 = datasets.OTDataset(
quad=source_0,
target_quad=target_0,
conditions=np.zeros_like(source_0) * 0.0
)
ds1 = dataloaders.OTDataSet(
ds1 = datasets.OTDataset(
quad=source_1,
target_quad=target_1,
conditions=np.ones_like(source_1) * 1.0
)
sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True)
sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True)
dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0)
dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1)
dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0)
dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1)

return dataloaders.ConditionalOTDataLoader((dl0, dl1))
return datasets.ConditionalOTDataset((dl0, dl1))


@pytest.fixture(scope="module")
@@ -160,13 +160,13 @@ def genot_data_loader_fused():
target_q = rng.normal(size=(100, 1)) + 1.0
source_lin = rng.normal(size=(100, 2))
target_lin = rng.normal(size=(100, 2)) + 1.0
dataset = dataloaders.OTDataSet(
dataset = datasets.OTDataset(
lin=source_lin,
quad=source_q,
target_lin=target_lin,
target_quad=target_q
)
return Torch_loader(dataset, batch_size=16, shuffle=True)
return DataLoader(dataset, batch_size=16, shuffle=True)


@pytest.fixture(scope="module")
@@ -183,14 +183,14 @@ def genot_data_loader_fused_conditional():
source_lin_1 = 2 * rng.normal(size=(100, 2))
target_lin_1 = 2 * rng.normal(size=(100, 2)) + 1.0

ds0 = dataloaders.OTDataSet(
ds0 = datasets.OTDataset(
lin=source_lin_0,
target_lin=target_lin_0,
quad=source_q_0,
target_quad=target_q_0,
conditions=np.zeros_like(source_lin_0) * 0.0
)
ds1 = dataloaders.OTDataSet(
ds1 = datasets.OTDataset(
lin=source_lin_1,
target_lin=target_lin_1,
quad=source_q_1,
@@ -199,6 +199,6 @@ def genot_data_loader_fused_conditional():
)
sampler0 = torch.utils.data.RandomSampler(ds0, replacement=True)
sampler1 = torch.utils.data.RandomSampler(ds1, replacement=True)
dl0 = Torch_loader(ds0, batch_size=16, sampler=sampler0)
dl1 = Torch_loader(ds1, batch_size=16, sampler=sampler1)
return dataloaders.ConditionalOTDataLoader((dl0, dl1))
dl0 = DataLoader(ds0, batch_size=16, sampler=sampler0)
dl1 = DataLoader(ds1, batch_size=16, sampler=sampler1)
return datasets.ConditionalOTDataset((dl0, dl1))
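
For reference, a short sketch (not part of the commit) of how a fixture like data_loader_gaussian_conditional above is consumed after the rename; it again uses only the constructor arguments shown in this diff, and the number of steps is arbitrary.

import numpy as np
import torch
from torch.utils.data import DataLoader

from ott.neural.data import datasets

rng = np.random.default_rng(0)
ds0 = datasets.OTDataset(lin=rng.normal(size=(100, 2)), conditions=np.zeros((100, 2)))
ds1 = datasets.OTDataset(lin=rng.normal(size=(100, 2)) + 1.0, conditions=np.ones((100, 2)))

# Samplers with replacement, mirroring the fixtures above.
dl0 = DataLoader(ds0, batch_size=16, sampler=torch.utils.data.RandomSampler(ds0, replacement=True))
dl1 = DataLoader(ds1, batch_size=16, sampler=torch.utils.data.RandomSampler(ds1, replacement=True))

loader = iter(datasets.ConditionalOTDataset((dl0, dl1), seed=0))
for _ in range(5):
    batch = next(loader)  # dict of torch tensors collated by the DataLoader
    assert batch["lin"].shape == (16, 2)
    # batch["conditions"] is all zeros or all ones, depending on which
    # wrapped loader was sampled at this step.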
