Skip to content

Commit

Permalink
fix: add sample dataset and missing enum
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Oct 22, 2024
1 parent bdbaf37 commit df4bbce
Show file tree
Hide file tree
Showing 41 changed files with 24 additions and 3 deletions.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added sample_dataset/latent_mean.npy
Binary file not shown.
Binary file added sample_dataset/latent_std.npy
Binary file not shown.
Binary file added sample_dataset/test/0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added sample_dataset/test/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added sample_dataset/test_caption_choices.npy
Binary file not shown.
Binary file added sample_dataset/test_captions.npy
Binary file not shown.
Binary file added sample_dataset/test_labels.npy
Binary file not shown.
Binary file added sample_dataset/test_latent.npy
Binary file not shown.
Binary file added sample_dataset/test_odd_one_out_labels.npy
Binary file not shown.
Binary file added sample_dataset/test_unpaired.npy
Binary file not shown.
Binary file added sample_dataset/train/0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added sample_dataset/train/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added sample_dataset/train/2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added sample_dataset/train/3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added sample_dataset/train_caption_choices.npy
Binary file not shown.
Binary file added sample_dataset/train_captions.npy
Binary file not shown.
Binary file added sample_dataset/train_labels.npy
Binary file not shown.
Binary file added sample_dataset/train_latent.npy
Binary file not shown.
Binary file added sample_dataset/train_odd_one_out_labels.npy
Binary file not shown.
Binary file added sample_dataset/train_unpaired.npy
Binary file not shown.
Binary file added sample_dataset/val/0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added sample_dataset/val/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added sample_dataset/val_caption_choices.npy
Binary file not shown.
Binary file added sample_dataset/val_captions.npy
Binary file not shown.
Binary file added sample_dataset/val_labels.npy
Binary file not shown.
Binary file added sample_dataset/val_latent.npy
Binary file not shown.
Binary file added sample_dataset/val_odd_one_out_labels.npy
Binary file not shown.
Binary file added sample_dataset/val_unpaired.npy
Binary file not shown.
27 changes: 24 additions & 3 deletions simple_shapes_dataset/dataset/domain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Callable, Iterable
from enum import Enum
from pathlib import Path
from typing import Any, NamedTuple, TypedDict

Expand All @@ -7,7 +8,26 @@
from PIL import Image
from shimmer import DataDomain, DomainDesc

from simple_shapes_dataset.types import DomainType

class DomainType(Enum):
v = DomainDesc("v", "v")
v_latents = DomainDesc("v", "v_latents")
attr = DomainDesc("attr", "attr")
t = DomainDesc("t", "t")
raw_text = DomainDesc("t", "raw_text")


class DomainModelVariantType(Enum):
v = (DomainType.v, "default")
attr = (DomainType.attr, "default")
attr_legacy = (DomainType.attr, "legacy")
attr_unpaired = (DomainType.attr, "unpaired")
v_latents = (DomainType.v_latents, "default")
v_latents_unpaired = (DomainType.v_latents, "unpaired")

def __init__(self, kind: DomainType, model_variant: str) -> None:
self.kind = kind
self.model_variant = model_variant


class SimpleShapesImages(DataDomain):
Expand Down Expand Up @@ -305,6 +325,7 @@ def get_default_domains(
) -> dict[DomainDesc, type[DataDomain]]:
domain_classes = {}
for domain in domains:
domain_desc = DomainType[domain].value if isinstance(domain, str) else domain
domain_classes[domain_desc] = DEFAULT_DOMAINS[domain_desc.kind]
if isinstance(domain, str):
domain = DomainType[domain].value
domain_classes[domain] = DEFAULT_DOMAINS[domain.kind]
return domain_classes

0 comments on commit df4bbce

Please sign in to comment.