Skip to content

Commit

Permalink
Rename CoresetSupportingModule. Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
francescodeaglio committed Sep 15, 2023
1 parent 66fb9f1 commit 1a0029c
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 17 deletions.
24 changes: 22 additions & 2 deletions integrationtests/selector/integrationtest_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,8 +910,8 @@ def test_get_available_labels(reset_after_trigger: bool):
GetAvailableLabelsRequest(pipeline_id=pipeline_id)
).available_labels

assert len(available_labels) == 2
assert 0 in available_labels and 1 in available_labels
# here we expect to have 0 labels since it's before the first trigger
assert len(available_labels) == 0

selector.inform_data_and_trigger(
DataInformRequest(
Expand All @@ -922,6 +922,14 @@ def test_get_available_labels(reset_after_trigger: bool):
)
)

available_labels = selector.get_available_labels(
GetAvailableLabelsRequest(pipeline_id=pipeline_id)
).available_labels

# we want all the labels belonging to the first trigger
assert len(available_labels) == 3
assert sorted(available_labels) == [0, 1, 189]

selector.inform_data_and_trigger(
DataInformRequest(
pipeline_id=pipeline_id,
Expand All @@ -931,14 +939,26 @@ def test_get_available_labels(reset_after_trigger: bool):
)
)

# this label (99) should not appear in the available labels since it belongs to a future trigger.
selector.inform_data(
DataInformRequest(
pipeline_id=pipeline_id,
keys=[99],
timestamps=[7],
labels=[99],
)
)

available_labels = selector.get_available_labels(
GetAvailableLabelsRequest(pipeline_id=pipeline_id)
).available_labels

if reset_after_trigger:
# only the last trigger must be considered but not point99
assert len(available_labels) == 3
assert sorted(available_labels) == [7, 10, 45]
else:
# every past point must be considered. Only point99 is excluded.
assert len(available_labels) == 6
assert sorted(available_labels) == [0, 1, 7, 10, 45, 189]

Expand Down
4 changes: 2 additions & 2 deletions modyn/models/articlenet/articlenet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

import torch
from modyn.models.coreset_methods_support import CoresetMethodsSupport
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import nn
from transformers import DistilBertModel

Expand Down Expand Up @@ -44,7 +44,7 @@ def __call__(self, data: torch.Tensor) -> torch.Tensor:
return pooled_output


class ArticleNetwork(CoresetMethodsSupport):
class ArticleNetwork(CoresetSupportingModule):
def __init__(self, num_classes: int) -> None:
super().__init__()
self.featurizer = DistilBertFeaturizer.from_pretrained("distilbert-base-uncased")
Expand Down
2 changes: 1 addition & 1 deletion modyn/models/coreset_methods_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def end_recording(self) -> None:
self.embedding = None


class CoresetMethodsSupport(nn.Module, ABC):
class CoresetSupportingModule(nn.Module, ABC):
"""
This class is used to support some Coreset Methods.
Embeddings, here defined as the activation before the last layer, are often used to estimate the importance of
Expand Down
4 changes: 2 additions & 2 deletions modyn/models/dlrm/dlrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
import torch
from modyn.models.coreset_methods_support import CoresetMethodsSupport, EmbeddingRecorder
from modyn.models.coreset_methods_support import CoresetSupportingModule, EmbeddingRecorder
from modyn.models.dlrm.nn.factories import create_interaction
from modyn.models.dlrm.nn.parts import DlrmBottom, DlrmTop
from modyn.models.dlrm.utils.install_lib import install_cuda_extensions_if_not_present
Expand All @@ -17,7 +17,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool)
self.model.to(device)


class DlrmModel(CoresetMethodsSupport):
class DlrmModel(CoresetSupportingModule):
# pylint: disable=too-many-instance-attributes
def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
super().__init__()
Expand Down
4 changes: 2 additions & 2 deletions modyn/models/dlrm/nn/parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Optional, Sequence, Tuple

import torch
from modyn.models.coreset_methods_support import CoresetMethodsSupport
from modyn.models.coreset_methods_support import CoresetSupportingModule
from modyn.models.dlrm.nn.embeddings import Embeddings
from modyn.models.dlrm.nn.factories import create_embeddings, create_mlp
from modyn.models.dlrm.nn.interactions import Interaction
Expand Down Expand Up @@ -105,7 +105,7 @@ def forward(self, numerical_input, categorical_inputs) -> Tuple[torch.Tensor, Op
return torch.cat(bottom_output, dim=1), bottom_mlp_output


class DlrmTop(CoresetMethodsSupport):
class DlrmTop(CoresetSupportingModule):
def __init__(self, top_mlp_sizes: Sequence[int], interaction: Interaction, device: str, use_cpp_mlp: bool = False):
super().__init__()

Expand Down
4 changes: 2 additions & 2 deletions modyn/models/fmownet/fmownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
import torch.nn.functional as F
from modyn.models.coreset_methods_support import CoresetMethodsSupport
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import nn
from torchvision.models import densenet121

Expand All @@ -20,7 +20,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool)
self.model.to(device)


class FmowNetModel(CoresetMethodsSupport):
class FmowNetModel(CoresetSupportingModule):
def __init__(self, num_classes: int) -> None:
super().__init__()
self.num_classes = num_classes
Expand Down
4 changes: 2 additions & 2 deletions modyn/models/resnet18/resnet18.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

import torch
from modyn.models.coreset_methods_support import CoresetMethodsSupport
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import Tensor, nn
from torchvision.models.resnet import BasicBlock, ResNet

Expand All @@ -17,7 +17,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool)
# torchvision https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py


class ResNet18Modyn(ResNet, CoresetMethodsSupport):
class ResNet18Modyn(ResNet, CoresetSupportingModule):
def __init__(self, model_configuration: dict[str, Any]) -> None:
super().__init__(BasicBlock, [2, 2, 2, 2], **model_configuration)

Expand Down
4 changes: 2 additions & 2 deletions modyn/models/yearbooknet/yearbooknet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

import torch
from modyn.models.coreset_methods_support import CoresetMethodsSupport
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import nn


Expand All @@ -18,7 +18,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool)
self.model.to(device)


class YearbookNetModel(CoresetMethodsSupport):
class YearbookNetModel(CoresetSupportingModule):
def __init__(self, num_input_channels: int, num_classes: int) -> None:
super().__init__()
self.enc = nn.Sequential(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def trigger(self) -> tuple[int, int, int]:
tuple[int, int, int]: Trigger ID, how many keys are in the trigger, number of overall partitions
"""
# TODO(#276) Unify AbstractSelection Strategy and LocalDatasetWriter

trigger_id = self._next_trigger_id
total_keys_in_trigger = 0

Expand Down Expand Up @@ -404,6 +405,7 @@ def get_available_labels(self) -> list[int]:
database.session.query(SelectorStateMetadata.label)
.filter(
SelectorStateMetadata.pipeline_id == self._pipeline_id,
SelectorStateMetadata.seen_in_trigger_id < self._next_trigger_id,
SelectorStateMetadata.seen_in_trigger_id >= self._next_trigger_id - self.tail_triggers - 1
if self.tail_triggers is not None
else True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def test_get_available_labels_reset():
@patch.multiple(AbstractSelectionStrategy, __abstractmethods__=set())
def test_get_available_labels_no_reset():
with MetadataDatabaseConnection(get_minimal_modyn_config()) as database:
# first trigger
# first batch of data
database.session.add(
SelectorStateMetadata(pipeline_id=1, sample_key=0, seen_in_trigger_id=0, timestamp=0, label=1)
)
Expand All @@ -346,10 +346,13 @@ def test_get_available_labels_no_reset():

abstr = AbstractSelectionStrategy({"limit": -1, "reset_after_trigger": False}, get_minimal_modyn_config(), 1, 1000)

assert sorted(abstr.get_available_labels()) == []
# simulate a trigger
abstr._next_trigger_id +=1
assert sorted(abstr.get_available_labels()) == [0, 1, 18]

with MetadataDatabaseConnection(get_minimal_modyn_config()) as database:
# second trigger
# another batch of data is inserted with just one more class
database.session.add(
SelectorStateMetadata(pipeline_id=1, sample_key=4, seen_in_trigger_id=1, timestamp=0, label=0)
)
Expand All @@ -358,5 +361,7 @@ def test_get_available_labels_no_reset():
)
database.session.commit()

assert sorted(abstr.get_available_labels()) == [0, 1, 18]
# simulate a trigger
abstr._next_trigger_id += 1
assert sorted(abstr.get_available_labels()) == [0, 1, 18, 890]

0 comments on commit 1a0029c

Please sign in to comment.