diff --git a/.github/workflows/workflow.yaml b/.github/workflows/workflow.yaml
index 828ee59e3..04ce7f662 100644
--- a/.github/workflows/workflow.yaml
+++ b/.github/workflows/workflow.yaml
@@ -7,7 +7,7 @@ defaults:
 jobs:
   flake8:
-    timeout-minutes: 20
+    timeout-minutes: 40
     runs-on: ubuntu-latest
     steps:
@@ -112,7 +112,7 @@ jobs:
   # Checks whether the base container works correctly.
   dockerized-unittests:
-    timeout-minutes: 30
+    timeout-minutes: 60
     runs-on: ubuntu-latest
     needs:
       - flake8
@@ -136,7 +136,7 @@ jobs:
   # Tests whether docker-compose up starts all components successfully and integration tests run through
   # Only one job to reduce Github CI usage
   integrationtests:
-    timeout-minutes: 30
+    timeout-minutes: 60
     runs-on: ubuntu-latest
     needs:
       - flake8
diff --git a/integrationtests/selector/integrationtest_selector.py b/integrationtests/selector/integrationtest_selector.py
index bf2efb152..c88bbed83 100644
--- a/integrationtests/selector/integrationtest_selector.py
+++ b/integrationtests/selector/integrationtest_selector.py
@@ -910,8 +910,8 @@ def test_get_available_labels(reset_after_trigger: bool):
         GetAvailableLabelsRequest(pipeline_id=pipeline_id)
     ).available_labels

-    assert len(available_labels) == 2
-    assert 0 in available_labels and 1 in available_labels
+    # we expect 0 available labels here, since this is before the first trigger
+    assert len(available_labels) == 0

     selector.inform_data_and_trigger(
         DataInformRequest(
@@ -922,7 +922,15 @@ def test_get_available_labels(reset_after_trigger: bool):
         )
     )

-    selector.inform_data(
+    available_labels = selector.get_available_labels(
+        GetAvailableLabelsRequest(pipeline_id=pipeline_id)
+    ).available_labels
+
+    # we want all the labels belonging to the first trigger
+    assert len(available_labels) == 3
+    assert sorted(available_labels) == [0, 1, 189]
+
+    selector.inform_data_and_trigger(
         DataInformRequest(
             pipeline_id=pipeline_id,
             keys=[4, 5, 6],
@@ -930,14 +938,27 @@ def test_get_available_labels(reset_after_trigger: bool):
             labels=[10, 7, 45],
         )
     )
+
+    # this label (99) should not appear in the available labels, since it belongs to a future trigger
+    selector.inform_data(
+        DataInformRequest(
+            pipeline_id=pipeline_id,
+            keys=[99],
+            timestamps=[7],
+            labels=[99],
+        )
+    )
+
     available_labels = selector.get_available_labels(
         GetAvailableLabelsRequest(pipeline_id=pipeline_id)
     ).available_labels

     if reset_after_trigger:
+        # only the last trigger must be considered; sample 99 is excluded
         assert len(available_labels) == 3
         assert sorted(available_labels) == [7, 10, 45]
     else:
+        # every past sample must be considered; only sample 99 is excluded
         assert len(available_labels) == 6
         assert sorted(available_labels) == [0, 1, 7, 10, 45, 189]
diff --git a/modyn/models/README.md b/modyn/models/README.md
index 07c308bb7..d5f9cddf4 100644
--- a/modyn/models/README.md
+++ b/modyn/models/README.md
@@ -5,4 +5,9 @@ The user can define models here. The model definition should take as a parameter
 # Wild Time models
 The code for the models used for WildTime is taken from the official [repository](https://github.com/huaxiuyao/Wild-Time). The original version is linked in each class.
-You can find [here](https://raw.githubusercontent.com/huaxiuyao/Wild-Time/main/LICENSE) a copy of the MIT license
\ No newline at end of file
+You can find [here](https://raw.githubusercontent.com/huaxiuyao/Wild-Time/main/LICENSE) a copy of the MIT license
+
+# Embedding Recorder
+Many coreset methods are adapted from the [DeepCore](https://github.com/PatrickZH/DeepCore/) library. To use them, the models must keep track of the embeddings (activations of the penultimate layer). This is
+done using the `EmbeddingRecorder` class, which is adapted from the aforementioned project.
+You can find a copy of their MIT license [here](https://raw.githubusercontent.com/PatrickZH/DeepCore/main/LICENSE.md)
\ No newline at end of file
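
For readers of the README section above: a minimal sketch of the pattern it describes, using the classes this PR introduces. Only the imported names come from the PR; the toy `TinyNet` model, its layer sizes, and the random input are illustrative assumptions.

```python
# Minimal sketch (toy model; only the imported names come from this PR).
import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import nn


class TinyNet(CoresetSupportingModule):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc2 = nn.Linear(4, 2)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.fc1(data))
        x = self.embedding_recorder(x)  # record the penultimate activation
        return self.fc2(x)

    def get_last_layer(self) -> nn.Module:
        return self.fc2


model = TinyNet()
model.embedding_recorder.start_recording()
output = model(torch.rand(5, 8))
assert model.embedding is not None and model.embedding.shape == (5, 4)
model.embedding_recorder.end_recording()  # stops recording and clears the stored tensor
```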
diff --git a/modyn/models/articlenet/articlenet.py b/modyn/models/articlenet/articlenet.py
index c22681a6e..244e5c623 100644
--- a/modyn/models/articlenet/articlenet.py
+++ b/modyn/models/articlenet/articlenet.py
@@ -1,6 +1,7 @@
 from typing import Any

 import torch
+from modyn.models.coreset_methods_support import CoresetSupportingModule
 from torch import nn
 from transformers import DistilBertModel

@@ -43,7 +44,7 @@ def __call__(self, data: torch.Tensor) -> torch.Tensor:
         return pooled_output


-class ArticleNetwork(nn.Module):
+class ArticleNetwork(CoresetSupportingModule):
     def __init__(self, num_classes: int) -> None:
         super().__init__()
         self.featurizer = DistilBertFeaturizer.from_pretrained("distilbert-base-uncased")
@@ -51,4 +52,8 @@ def __init__(self, num_classes: int) -> None:

     def forward(self, data: torch.Tensor) -> torch.Tensor:
         embedding = self.featurizer(data)
+        embedding = self.embedding_recorder(embedding)
         return self.classifier(embedding)
+
+    def get_last_layer(self) -> nn.Module:
+        return self.classifier
diff --git a/modyn/models/coreset_methods_support.py b/modyn/models/coreset_methods_support.py
new file mode 100644
index 000000000..ab669df78
--- /dev/null
+++ b/modyn/models/coreset_methods_support.py
@@ -0,0 +1,60 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+from torch import nn
+
+
+# acknowledgment: github.com/PatrickZH/DeepCore/
+class EmbeddingRecorder(nn.Module):
+    def __init__(self, record_embedding: bool = False):
+        super().__init__()
+        self.record_embedding = record_embedding
+        self.embedding: Optional[torch.Tensor] = None
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        if self.record_embedding:
+            self.embedding = tensor
+        return tensor
+
+    def start_recording(self) -> None:
+        self.record_embedding = True
+
+    def end_recording(self) -> None:
+        self.record_embedding = False
+        self.embedding = None
+
+
+class CoresetSupportingModule(nn.Module, ABC):
+    """
+    This class provides the support needed by several coreset methods.
+    Embeddings, defined here as the activations before the last layer, are often used to estimate the importance of
+    a point. To implement this class correctly, it is necessary to
+    - implement the get_last_layer method
+    - modify the forward pass so that the last-layer embedding is recorded. For example, in a simple network like
+        x = self.fc1(input)
+        x = self.fc2(x)
+        output = self.fc3(x)
+    it must be modified as follows
+        x = self.fc1(input)
+        x = self.fc2(x)
+        x = self.embedding_recorder(x)
+        output = self.fc3(x)
+    """
+
+    def __init__(self, record_embedding: bool = False) -> None:
+        super().__init__()
+        self.embedding_recorder = EmbeddingRecorder(record_embedding)
+
+    @property
+    def embedding(self) -> Optional[torch.Tensor]:
+        assert self.embedding_recorder is not None
+        return self.embedding_recorder.embedding
+
+    @abstractmethod
+    def get_last_layer(self) -> nn.Module:
+        """
+        Returns the last layer. Used, for example, to obtain the
+        dimensions of the tensors entering and leaving the last layer.
+        """
+        raise NotImplementedError()
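
To illustrate why the recorder exists: a hedged sketch of how a DeepCore-style scorer might combine the recorded embedding with `get_last_layer()` to compute per-sample last-layer gradient norms (GraNd-style). The function below is an illustrative assumption, not part of this PR; it relies on the identity that the Frobenius norm of an outer product factorizes, so the weight gradient of a linear layer under cross-entropy loss has norm `||grad_logits|| * ||embedding||` per sample (bias term ignored).

```python
# Illustrative sketch, not part of this PR: GraNd-style last-layer gradient
# norms computed from the recorded embedding. `model` is assumed to be any
# CoresetSupportingModule whose forward pass feeds the recorder.
import torch
import torch.nn.functional as F


def last_layer_grad_norms(model, data: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    model.embedding_recorder.start_recording()
    with torch.no_grad():  # the gradient is computed analytically below
        logits = model(data)  # this forward pass records the embedding
        embedding = model.embedding  # activations entering the last layer
        assert embedding.shape[1] == model.get_last_layer().in_features
        probs = F.softmax(logits, dim=1)
        one_hot = F.one_hot(targets, num_classes=logits.shape[1])
        grad_logits = probs - one_hot  # d(cross-entropy)/d(logits), per sample
        # the per-sample weight gradient of the last linear layer is the outer
        # product of grad_logits and embedding, whose Frobenius norm factorizes:
        norms = grad_logits.norm(dim=1) * embedding.norm(dim=1)
    model.embedding_recorder.end_recording()
    return norms
```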
diff --git a/modyn/models/dlrm/dlrm.py b/modyn/models/dlrm/dlrm.py
index 86e9657f5..1c0b050d8 100644
--- a/modyn/models/dlrm/dlrm.py
+++ b/modyn/models/dlrm/dlrm.py
@@ -1,7 +1,8 @@
-from typing import Any
+from typing import Any, Optional

 import numpy as np
 import torch
+from modyn.models.coreset_methods_support import CoresetSupportingModule, EmbeddingRecorder
 from modyn.models.dlrm.nn.factories import create_interaction
 from modyn.models.dlrm.nn.parts import DlrmBottom, DlrmTop
 from modyn.models.dlrm.utils.install_lib import install_cuda_extensions_if_not_present
@@ -16,7 +17,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool)
         self.model.to(device)


-class DlrmModel(nn.Module):
+class DlrmModel(CoresetSupportingModule):
     # pylint: disable=too-many-instance-attributes
     def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
         super().__init__()
@@ -124,3 +125,15 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
             numerical_input, self.reorder_categorical_input(categorical_input)
         )
         return self.top_model(from_bottom, bottom_mlp_output).squeeze()
+
+    # delegate the embedding handling to the top model
+    @property
+    def embedding(self) -> Optional[torch.Tensor]:
+        return self.top_model.embedding
+
+    @property
+    def embedding_recorder(self) -> EmbeddingRecorder:
+        return self.top_model.embedding_recorder
+
+    def get_last_layer(self) -> nn.Module:
+        return self.top_model.get_last_layer()
diff --git a/modyn/models/dlrm/nn/parts.py b/modyn/models/dlrm/nn/parts.py
index 7d9e00535..2f19b9897 100644
--- a/modyn/models/dlrm/nn/parts.py
+++ b/modyn/models/dlrm/nn/parts.py
@@ -19,6 +19,7 @@
 from typing import Optional, Sequence, Tuple

 import torch
+from modyn.models.coreset_methods_support import CoresetSupportingModule
 from modyn.models.dlrm.nn.embeddings import Embeddings
 from modyn.models.dlrm.nn.factories import create_embeddings, create_mlp
 from modyn.models.dlrm.nn.interactions import Interaction
@@ -104,7 +105,7 @@ def forward(self, numerical_input, categorical_inputs) -> Tuple[torch.Tensor, Op
         return torch.cat(bottom_output, dim=1), bottom_mlp_output


-class DlrmTop(nn.Module):
+class DlrmTop(CoresetSupportingModule):
     def __init__(self, top_mlp_sizes: Sequence[int], interaction: Interaction, device: str, use_cpp_mlp: bool = False):
         super().__init__()

@@ -127,4 +128,9 @@ def forward(self, bottom_output, bottom_mlp_output):
             bottom_mlp_output (Tensor): with shape [batch_size, embedding_dim]
         """
         interaction_output = self.interaction.interact(bottom_output, bottom_mlp_output)
-        return self.out(self.mlp(interaction_output))
+        mlp_output = self.mlp(interaction_output)
+        mlp_output = self.embedding_recorder(mlp_output)
+        return self.out(mlp_output)
+
+    def get_last_layer(self) -> nn.Module:
+        return self.out
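
The DLRM changes above use a delegation pattern: the penultimate activation lives inside `DlrmTop`, so `DlrmModel` forwards `embedding`, `embedding_recorder`, and `get_last_layer` to it via properties. A standalone sketch of the same pattern follows; the `Head` and `Composite` toy modules are assumptions for illustration, not part of this PR.

```python
# Illustrative sketch of the delegation pattern used by DlrmModel/DlrmTop
# above; Head and Composite are toy stand-ins, not part of this PR.
from typing import Optional

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule, EmbeddingRecorder
from torch import nn


class Head(CoresetSupportingModule):
    # owns the last layer, like DlrmTop
    def __init__(self) -> None:
        super().__init__()
        self.hidden = nn.Linear(8, 4)
        self.out = nn.Linear(4, 2)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.hidden(data))
        x = self.embedding_recorder(x)
        return self.out(x)

    def get_last_layer(self) -> nn.Module:
        return self.out


class Composite(CoresetSupportingModule):
    # delegates all embedding handling to its head, like DlrmModel
    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Linear(16, 8)
        self.head = Head()

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        return self.head(torch.relu(self.backbone(data)))

    @property
    def embedding(self) -> Optional[torch.Tensor]:
        return self.head.embedding

    @property
    def embedding_recorder(self) -> EmbeddingRecorder:
        return self.head.embedding_recorder

    def get_last_layer(self) -> nn.Module:
        return self.head.get_last_layer()


model = Composite()
model.embedding_recorder.start_recording()  # resolves to the head's recorder
assert model(torch.rand(5, 16)).shape == (5, 2)
assert model.embedding is not None and model.embedding.shape == (5, 4)
```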
diff --git a/modyn/models/fmownet/fmownet.py b/modyn/models/fmownet/fmownet.py
index 1be61a5b0..d8c9caa22 100644
--- a/modyn/models/fmownet/fmownet.py
+++ b/modyn/models/fmownet/fmownet.py
@@ -2,6 +2,7 @@

 import torch
 import torch.nn.functional as F
+from modyn.models.coreset_methods_support import CoresetSupportingModule
 from torch import nn
 from torchvision.models import densenet121

@@ -19,7 +20,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool)
         self.model.to(device)


-class FmowNetModel(nn.Module):
+class FmowNetModel(CoresetSupportingModule):
     def __init__(self, num_classes: int) -> None:
         super().__init__()
         self.num_classes = num_classes
@@ -31,5 +32,8 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
         out = F.relu(features, inplace=True)
         out = F.adaptive_avg_pool2d(out, (1, 1))
         out = torch.flatten(out, 1)
-
+        out = self.embedding_recorder(out)
         return self.classifier(out)
+
+    def get_last_layer(self) -> nn.Module:
+        return self.classifier
diff --git a/modyn/models/resnet18/resnet18.py b/modyn/models/resnet18/resnet18.py
index 0fb65766a..2b7f1da04 100644
--- a/modyn/models/resnet18/resnet18.py
+++ b/modyn/models/resnet18/resnet18.py
@@ -1,10 +1,44 @@
 from typing import Any

-from torchvision import models
+import torch
+from modyn.models.coreset_methods_support import CoresetSupportingModule
+from torch import Tensor, nn
+from torchvision.models.resnet import BasicBlock, ResNet


 class ResNet18:
     # pylint: disable-next=unused-argument
     def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
-        self.model = models.__dict__["resnet18"](**model_configuration)
+        self.model = ResNet18Modyn(model_configuration)
         self.model.to(device)
+
+
+# the following class is adapted from
+# torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
+
+
+class ResNet18Modyn(ResNet, CoresetSupportingModule):
+    def __init__(self, model_configuration: dict[str, Any]) -> None:
+        super().__init__(BasicBlock, [2, 2, 2, 2], **model_configuration)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        # the following line is the only difference compared to the original implementation
+        x = self.embedding_recorder(x)
+        x = self.fc(x)
+
+        return x
+
+    def get_last_layer(self) -> nn.Module:
+        return self.fc
diff --git a/modyn/models/yearbooknet/yearbooknet.py b/modyn/models/yearbooknet/yearbooknet.py
index a0f1b4dc6..032e09692 100644
--- a/modyn/models/yearbooknet/yearbooknet.py
+++ b/modyn/models/yearbooknet/yearbooknet.py
@@ -1,6 +1,7 @@
 from typing import Any

 import torch
+from modyn.models.coreset_methods_support import CoresetSupportingModule
 from torch import nn

@@ -17,7 +18,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool)
         self.model.to(device)


-class YearbookNetModel(nn.Module):
+class YearbookNetModel(CoresetSupportingModule):
     def __init__(self, num_input_channels: int, num_classes: int) -> None:
         super().__init__()
         self.enc = nn.Sequential(
@@ -37,5 +38,8 @@ def conv_block(self, in_channels: int, out_channels: int) -> nn.Module:
     def forward(self, data: torch.Tensor) -> torch.Tensor:
         data = self.enc(data)
         data = torch.mean(data, dim=(2, 3))
-
+        data = self.embedding_recorder(data)
         return self.classifier(data)
+
+    def get_last_layer(self) -> nn.Module:
+        return self.classifier
diff --git a/modyn/selector/internal/selector_strategies/abstract_selection_strategy.py b/modyn/selector/internal/selector_strategies/abstract_selection_strategy.py
index b91925fec..533658706 100644
--- a/modyn/selector/internal/selector_strategies/abstract_selection_strategy.py
+++ b/modyn/selector/internal/selector_strategies/abstract_selection_strategy.py
@@ -186,6 +186,7 @@ def trigger(self) -> tuple[int, int, int]:
             tuple[int, int, int]: Trigger ID, how many keys are in the trigger, number of overall partitions
         """
         # TODO(#276) Unify AbstractSelection Strategy and LocalDatasetWriter
+        trigger_id = self._next_trigger_id
         total_keys_in_trigger = 0

@@ -404,7 +405,8 @@ def get_available_labels(self) -> list[int]:
             database.session.query(SelectorStateMetadata.label)
             .filter(
                 SelectorStateMetadata.pipeline_id == self._pipeline_id,
-                SelectorStateMetadata.seen_in_trigger_id >= self._next_trigger_id - self.tail_triggers
+                SelectorStateMetadata.seen_in_trigger_id < self._next_trigger_id,
+                SelectorStateMetadata.seen_in_trigger_id >= self._next_trigger_id - self.tail_triggers - 1
                 if self.tail_triggers is not None
                 else True,
             )
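
The new filter in `get_available_labels` admits the half-open window of trigger IDs `[next_trigger_id - tail_triggers - 1, next_trigger_id)` when `tail_triggers` is set, and everything strictly before `next_trigger_id` otherwise; in particular, data informed after the last trigger is no longer visible. A small illustrative helper (not part of the codebase) that mirrors this arithmetic and matches the tests below:

```python
# Illustrative helper mirroring the SQL filter above; not part of the codebase.
from typing import Optional


def visible_triggers(next_trigger_id: int, tail_triggers: Optional[int]) -> list[int]:
    lower = 0 if tail_triggers is None else max(0, next_trigger_id - tail_triggers - 1)
    return list(range(lower, next_trigger_id))


assert visible_triggers(0, None) == []      # before the first trigger, nothing is visible
assert visible_triggers(2, None) == [0, 1]  # no reset: all past triggers
assert visible_triggers(1, 0) == [0]        # reset_after_trigger: only the latest trigger
assert visible_triggers(3, 1) == [1, 2]     # latest trigger plus one tail trigger
```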
diff --git a/modyn/tests/models/test_dlrm.py b/modyn/tests/models/test_dlrm.py
index b0fb2607b..4471b0b2d 100644
--- a/modyn/tests/models/test_dlrm.py
+++ b/modyn/tests/models/test_dlrm.py
@@ -67,3 +67,48 @@ def test_dlrm_reorder_categorical_input():
     assert reordered_test_data.shape == (64, 26)
     assert reordered_test_data.dtype == torch.long
     assert torch.equal(reordered_test_data, input_data)
+
+
+def test_get_last_layer():
+    net = DLRM(get_dlrm_configuration(), "cpu", False)
+    last_layer = net.model.get_last_layer()
+
+    assert isinstance(last_layer, torch.nn.Linear)
+    assert last_layer.in_features == 16
+    assert last_layer.out_features == 1
+    assert last_layer.bias.shape == (1,)
+    assert last_layer.weight.shape == (1, 16)
+
+
+def test_dlrm_no_side_effect():
+    model = DLRM(get_dlrm_configuration(), "cpu", False)
+
+    data = {
+        "numerical_input": torch.ones((64, 13), dtype=torch.float32),
+        "categorical_input": torch.ones((64, 26), dtype=torch.long),
+    }
+    out_off = model.model(data)
+    model.model.embedding_recorder.record_embedding = True
+    out_on = model.model(data)
+
+    assert torch.equal(out_on, out_off)
+
+
+def test_shape_embedding_recorder():
+    model = DLRM(get_dlrm_configuration(), "cpu", False)
+
+    data = {
+        "numerical_input": torch.ones((64, 13), dtype=torch.float32),
+        "categorical_input": torch.ones((64, 26), dtype=torch.long),
+    }
+    model.model(data)
+    assert model.model.embedding is None
+    model.model.embedding_recorder.record_embedding = True
+
+    last_layer = model.model.get_last_layer()
+    recorded_output = model.model(data)
+    recorded_embedding = model.model.embedding
+
+    assert recorded_embedding is not None
+    assert recorded_embedding.shape == (64, last_layer.in_features)
+    assert torch.equal(torch.squeeze(last_layer(recorded_embedding)), recorded_output)
diff --git a/modyn/tests/models/test_embedding_recorder.py b/modyn/tests/models/test_embedding_recorder.py
new file mode 100644
index 000000000..83c3c63c8
--- /dev/null
+++ b/modyn/tests/models/test_embedding_recorder.py
@@ -0,0 +1,33 @@
+import torch
+from modyn.models.coreset_methods_support import EmbeddingRecorder
+
+
+def test_embedding_recording():
+    recorder = EmbeddingRecorder()
+    recorder.start_recording()
+    input_tensor = torch.tensor([1, 2, 3])
+    output_tensor = recorder(input_tensor)
+    assert torch.equal(recorder.embedding, input_tensor)
+    assert torch.equal(output_tensor, input_tensor)
+
+
+def test_no_embedding_recording():
+    recorder = EmbeddingRecorder()
+    input_tensor = torch.tensor([4, 5, 6])
+    output_tensor = recorder(input_tensor)
+    assert recorder.embedding is None
+    assert torch.equal(output_tensor, input_tensor)
+
+
+def test_toggle_embedding_recording():
+    recorder = EmbeddingRecorder()
+    recorder.start_recording()
+    input_tensor = torch.tensor([7, 8, 9])
+    output_tensor = recorder(input_tensor)
+    assert torch.equal(recorder.embedding, input_tensor)
+    assert torch.equal(output_tensor, input_tensor)
+    recorder.end_recording()
+    input_tensor = torch.tensor([10, 11, 12])
+    output_tensor = recorder(input_tensor)
+    assert recorder.embedding is None
+    assert torch.equal(output_tensor, input_tensor)
diff --git a/modyn/tests/models/test_fmownet.py b/modyn/tests/models/test_fmownet.py
new file mode 100644
index 000000000..26a93b5ae
--- /dev/null
+++ b/modyn/tests/models/test_fmownet.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn.functional as F
+from modyn.models import FmowNet
+
+
+def test_forward_with_embedding_recording():
+    net = FmowNet({"num_classes": 10}, "cpu", False)
+    input_data = torch.rand(30, 3, 32, 32)
+    output = net.model(input_data)
+    assert output.shape == (30, 10)
+    assert net.model.embedding is None
+
+    net.model.embedding_recorder.start_recording()
+    input_data = torch.rand(30, 3, 32, 32)
+    output = net.model(input_data)
+    assert output.shape == (30, 10)
+
+    assert net.model.embedding is not None
+
+    expected_embedding = F.adaptive_avg_pool2d(F.relu(net.model.enc(input_data), inplace=True), (1, 1))
+    assert torch.equal(torch.flatten(expected_embedding, 1), net.model.embedding)
+    assert torch.equal(net.model.classifier(net.model.embedding), output)
+
+
+def test_get_last_layer():
+    net = FmowNet({"num_classes": 10}, "cpu", False)
+    last_layer = net.model.get_last_layer()
+
+    assert isinstance(last_layer, torch.nn.Linear)
+    assert last_layer.in_features == 1024
+    assert last_layer.out_features == 10
+    assert last_layer.bias.shape == (10,)
+    assert last_layer.weight.shape == (10, 1024)
diff --git a/modyn/tests/models/test_resnet18.py b/modyn/tests/models/test_resnet18.py
new file mode 100644
index 000000000..1fc004eb9
--- /dev/null
+++ b/modyn/tests/models/test_resnet18.py
@@ -0,0 +1,29 @@
+import torch
+from modyn.models import ResNet18
+
+
+def test_forward_with_embedding_recording():
+    net = ResNet18({"num_classes": 10}, "cpu", False)
+    input_data = torch.rand(30, 3, 32, 32)
+    output = net.model(input_data)
+    assert output.shape == (30, 10)
+    assert net.model.embedding is None
+
+    net.model.embedding_recorder.start_recording()
+    input_data = torch.rand(30, 3, 32, 32)
+    output = net.model(input_data)
+    assert output.shape == (30, 10)
+
+    assert net.model.embedding is not None
+    assert torch.equal(net.model.fc(net.model.embedding), output)
+
+
+def test_get_last_layer():
+    net = ResNet18({"num_classes": 10}, "cpu", False)
+    last_layer = net.model.get_last_layer()
+
+    assert isinstance(last_layer, torch.nn.Linear)
+    assert last_layer.in_features == 512
+    assert last_layer.out_features == 10
+    assert last_layer.bias.shape == (10,)
+    assert last_layer.weight.shape == (10, 512)
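
The shape assertions above also document a useful invariant: the recorded embedding width always equals `get_last_layer().in_features`, so callers can size per-sample buffers without running a forward pass. A small sketch; the buffer and sample count are assumptions for illustration.

```python
# Sketch: sizing a per-sample buffer from the last layer; illustrative only.
import torch
from modyn.models import ResNet18

net = ResNet18({"num_classes": 10}, "cpu", False)
embedding_dim = net.model.get_last_layer().in_features  # 512 for ResNet-18
buffer = torch.empty((1000, embedding_dim))  # hypothetical per-sample embedding store
```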
diff --git a/modyn/tests/models/test_yearbook_net.py b/modyn/tests/models/test_yearbook_net.py
index a21526c43..e3d94e92d 100644
--- a/modyn/tests/models/test_yearbook_net.py
+++ b/modyn/tests/models/test_yearbook_net.py
@@ -42,3 +42,34 @@ def test_model_conv_block():

     # Assert that the output has the correct shape
     assert output.shape == (batch_size, 32, height // 2, width // 2)
+
+
+def test_forward_with_embedding_recording():
+    net = YearbookNet({"num_input_channels": 3, "num_classes": 10}, "cpu", False)
+    input_data = torch.rand(30, 3, 32, 32)
+    output = net.model(input_data)
+    assert output.shape == (30, 10)
+    assert net.model.embedding is None
+
+    net.model.embedding_recorder.start_recording()
+    input_data = torch.rand(30, 3, 32, 32)
+    output = net.model(input_data)
+    assert output.shape == (30, 10)
+
+    # expected embedding
+    expected_embedding = net.model.enc(input_data)
+    expected_embedding = torch.mean(expected_embedding, dim=(2, 3))
+
+    assert net.model.embedding is not None
+    assert torch.equal(net.model.embedding, expected_embedding)
+
+
+def test_get_last_layer():
+    net = YearbookNet({"num_input_channels": 3, "num_classes": 10}, "cpu", False)
+    last_layer = net.model.get_last_layer()
+
+    assert isinstance(last_layer, torch.nn.Linear)
+    assert last_layer.in_features == 32
+    assert last_layer.out_features == 10
+    assert last_layer.bias.shape == (10,)
+    assert last_layer.weight.shape == (10, 32)
diff --git a/modyn/tests/selector/internal/selector_strategies/test_abstract_selection_strategy.py b/modyn/tests/selector/internal/selector_strategies/test_abstract_selection_strategy.py
index 72968818c..0e2e78f8e 100644
--- a/modyn/tests/selector/internal/selector_strategies/test_abstract_selection_strategy.py
+++ b/modyn/tests/selector/internal/selector_strategies/test_abstract_selection_strategy.py
@@ -309,7 +309,7 @@ def test_get_available_labels_reset():
         database.session.commit()

     abstr = AbstractSelectionStrategy({"limit": -1, "reset_after_trigger": True}, get_minimal_modyn_config(), 1, 1000)
-
+    abstr._next_trigger_id += 1
     assert sorted(abstr.get_available_labels()) == [0, 1, 18]

     with MetadataDatabaseConnection(get_minimal_modyn_config()) as database:
@@ -329,7 +329,7 @@ def test_get_available_labels_reset():
 @patch.multiple(AbstractSelectionStrategy, __abstractmethods__=set())
 def test_get_available_labels_no_reset():
     with MetadataDatabaseConnection(get_minimal_modyn_config()) as database:
-        # first trigger
+        # first batch of data
         database.session.add(
             SelectorStateMetadata(pipeline_id=1, sample_key=0, seen_in_trigger_id=0, timestamp=0, label=1)
         )
@@ -346,10 +346,13 @@ def test_get_available_labels_no_reset():

     abstr = AbstractSelectionStrategy({"limit": -1, "reset_after_trigger": False}, get_minimal_modyn_config(), 1, 1000)

+    assert sorted(abstr.get_available_labels()) == []
+    # simulate a trigger
+    abstr._next_trigger_id += 1
     assert sorted(abstr.get_available_labels()) == [0, 1, 18]

     with MetadataDatabaseConnection(get_minimal_modyn_config()) as database:
-        # second trigger
+        # another batch of data is inserted, containing just one new class
         database.session.add(
             SelectorStateMetadata(pipeline_id=1, sample_key=4, seen_in_trigger_id=1, timestamp=0, label=0)
         )
@@ -358,5 +361,7 @@ def test_get_available_labels_no_reset():
         )
         database.session.commit()

+    assert sorted(abstr.get_available_labels()) == [0, 1, 18]
+    # simulate a trigger
     abstr._next_trigger_id += 1
     assert sorted(abstr.get_available_labels()) == [0, 1, 18, 890]