Introduce CoresetMethodsSupport #294

Merged on Sep 18, 2023 (28 commits)

Changes from all commits

Commits
- 226ceb2 Introduction of Modyn Model (francescodeaglio, Jun 22, 2023)
- fdd4440 Inefficient ResNet18(ModynModel) (francescodeaglio, Jun 22, 2023)
- fbe1720 Update torchvision's resnet (francescodeaglio, Jun 24, 2023)
- 2ba9c33 Merge branch 'main' into feature/francescodeaglio/modyn_model (francescodeaglio, Jun 24, 2023)
- 15f1ae3 Resnet torchvision simplified (francescodeaglio, Jun 28, 2023)
- 9012ba6 Merge branch 'main' into feature/francescodeaglio/modyn_model (francescodeaglio, Jul 5, 2023)
- 2b1c862 Merge branch 'main' into feature/francescodeaglio/modyn_model (francescodeaglio, Aug 2, 2023)
- 9959482 Changed embedding (francescodeaglio, Aug 2, 2023)
- 7cf318d Merge branch 'main' into feature/francescodeaglio/modyn_model (francescodeaglio, Aug 2, 2023)
- 1d1987d Now every model implements ModynModel (francescodeaglio, Aug 2, 2023)
- 4e81dd4 Tests (francescodeaglio, Aug 2, 2023)
- 688cd64 Remove useless class (francescodeaglio, Aug 2, 2023)
- d190885 Dlrm embedding recorder (francescodeaglio, Aug 2, 2023)
- 0eff629 Reset (francescodeaglio, Aug 7, 2023)
- 1f9e595 Renamed ModynModels to CoresetMethodsSupport (francescodeaglio, Aug 7, 2023)
- 36140e7 ResNet18 (francescodeaglio, Aug 7, 2023)
- cbae4fd Docstrings (francescodeaglio, Aug 7, 2023)
- 7fd7870 Test dlrm (francescodeaglio, Aug 7, 2023)
- b5edf3a Test resnet (francescodeaglio, Aug 7, 2023)
- 17e167e Test fmownet (francescodeaglio, Aug 7, 2023)
- 2ef56e1 Removed context manager. Simplified logic (francescodeaglio, Aug 7, 2023)
- 66fb9f1 Fix wrong test (francescodeaglio, Aug 7, 2023)
- 1a0029c Rename CoresetSupportingModule. Fix test (francescodeaglio, Sep 15, 2023)
- 15438b9 Merge branch 'main' into feature/francescodeaglio/modyn_model (francescodeaglio, Sep 15, 2023)
- 238b5f5 Update workflow.yaml (francescodeaglio, Sep 15, 2023)
- cdfeb53 Reset timeout (francescodeaglio, Sep 16, 2023)
- f0361ea Update workflow.yaml (francescodeaglio, Sep 16, 2023)
- 5e35534 Update workflow.yaml (francescodeaglio, Sep 16, 2023)
6 changes: 3 additions & 3 deletions .github/workflows/workflow.yaml
@@ -7,7 +7,7 @@ defaults:

jobs:
flake8:
timeout-minutes: 20
timeout-minutes: 40
runs-on: ubuntu-latest

steps:
@@ -112,7 +112,7 @@ jobs:

# Checks whether the base container works correctly.
dockerized-unittests:
timeout-minutes: 30
timeout-minutes: 60
runs-on: ubuntu-latest
needs:
- flake8
@@ -136,7 +136,7 @@ jobs:
# Tests whether docker-compose up starts all components successfully and integration tests run through
# Only one job to reduce Github CI usage
integrationtests:
timeout-minutes: 30
timeout-minutes: 60
runs-on: ubuntu-latest
needs:
- flake8
27 changes: 24 additions & 3 deletions integrationtests/selector/integrationtest_selector.py
@@ -910,8 +910,8 @@ def test_get_available_labels(reset_after_trigger: bool):
GetAvailableLabelsRequest(pipeline_id=pipeline_id)
).available_labels

assert len(available_labels) == 2
assert 0 in available_labels and 1 in available_labels
# here we expect to have 0 labels since it's before the first trigger
assert len(available_labels) == 0

selector.inform_data_and_trigger(
DataInformRequest(
@@ -922,22 +922,43 @@
)
)

selector.inform_data(
available_labels = selector.get_available_labels(
GetAvailableLabelsRequest(pipeline_id=pipeline_id)
).available_labels

# we want all the labels belonging to the first trigger
assert len(available_labels) == 3
assert sorted(available_labels) == [0, 1, 189]

selector.inform_data_and_trigger(
DataInformRequest(
pipeline_id=pipeline_id,
keys=[4, 5, 6],
timestamps=[4, 5, 6],
labels=[10, 7, 45],
)
)

# this label (99) should not appear in the available labels since it belongs to a future trigger.
selector.inform_data(
DataInformRequest(
pipeline_id=pipeline_id,
keys=[99],
timestamps=[7],
labels=[99],
)
)

available_labels = selector.get_available_labels(
GetAvailableLabelsRequest(pipeline_id=pipeline_id)
).available_labels

if reset_after_trigger:
# only the last trigger must be considered but not point99
assert len(available_labels) == 3
assert sorted(available_labels) == [7, 10, 45]
else:
# every past point must be considered. Only point99 is excluded.
assert len(available_labels) == 6
assert sorted(available_labels) == [0, 1, 7, 10, 45, 189]

7 changes: 6 additions & 1 deletion modyn/models/README.md
@@ -5,4 +5,9 @@ The user can define models here. The model definition should take as a parameter
# Wild Time models
The code for the models used for WildTime is taken from the official [repository](https://github.com/huaxiuyao/Wild-Time).
The original version is linked in each class.
You can find [here](https://raw.githubusercontent.com/huaxiuyao/Wild-Time/main/LICENSE) a copy of the MIT license

# Embedding Recorder
Many coreset methods are adapted from the [DeepCore](https://github.com/PatrickZH/DeepCore/) library. To use them, the models must keep track of the embeddings (activations of the penultimate layer). This is done using the `EmbeddingRecorder` class, which is adapted from the aforementioned project.
You can find a copy of their MIT license [here](https://raw.githubusercontent.com/PatrickZH/DeepCore/main/LICENSE.md).
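For illustration, the recording API introduced in this PR (`start_recording`, `end_recording`, and the `embedding` property) can be wrapped in a small helper. This is a hypothetical sketch, not code from the repository; the helper name and structure are made up:

```python
# Hypothetical helper built on the EmbeddingRecorder API from this PR:
# run one forward pass with recording enabled and return the embeddings.
import torch


def forward_with_embedding(model, batch: torch.Tensor) -> torch.Tensor:
    model.embedding_recorder.start_recording()
    try:
        model(batch)  # the forward pass stores the penultimate activation
        assert model.embedding is not None
        return model.embedding.detach()
    finally:
        # end_recording() also clears the stored tensor, freeing memory
        model.embedding_recorder.end_recording()
```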
7 changes: 6 additions & 1 deletion modyn/models/articlenet/articlenet.py
@@ -1,6 +1,7 @@
from typing import Any

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import nn
from transformers import DistilBertModel

@@ -43,12 +44,16 @@ def __call__(self, data: torch.Tensor) -> torch.Tensor:
return pooled_output


class ArticleNetwork(nn.Module):
class ArticleNetwork(CoresetSupportingModule):
def __init__(self, num_classes: int) -> None:
super().__init__()
self.featurizer = DistilBertFeaturizer.from_pretrained("distilbert-base-uncased")
self.classifier = nn.Linear(self.featurizer.d_out, num_classes)

def forward(self, data: torch.Tensor) -> torch.Tensor:
embedding = self.featurizer(data)
embedding = self.embedding_recorder(embedding)
return self.classifier(embedding)

def get_last_layer(self) -> nn.Module:
return self.classifier
60 changes: 60 additions & 0 deletions modyn/models/coreset_methods_support.py
@@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import nn


# acknowledgment: github.com/PatrickZH/DeepCore/
class EmbeddingRecorder(nn.Module):
def __init__(self, record_embedding: bool = False):
super().__init__()
self.record_embedding = record_embedding
self.embedding: Optional[torch.Tensor] = None

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
if self.record_embedding:
self.embedding = tensor
return tensor

def start_recording(self) -> None:
self.record_embedding = True

def end_recording(self) -> None:
self.record_embedding = False
self.embedding = None


class CoresetSupportingModule(nn.Module, ABC):
"""
This class is used to support some Coreset Methods.
Embeddings, here defined as the activation before the last layer, are often used to estimate the importance of
a point. To implement this class correctly, it is necessary to
- implement the get_last_layer method
- modify the forward pass so that the last layer embedding is recorded. For example, in a simple network like
x = self.fc1(input)
x = self.fc2(x)
output = self.fc3(x)
it must be modified as follows
x = self.fc1(input)
x = self.fc2(x)
x = self.embedding_recorder(x)
output = self.fc3(x)
"""

def __init__(self, record_embedding: bool = False) -> None:
super().__init__()
self.embedding_recorder = EmbeddingRecorder(record_embedding)

@property
def embedding(self) -> Optional[torch.Tensor]:
assert self.embedding_recorder is not None
return self.embedding_recorder.embedding

@abstractmethod
def get_last_layer(self) -> nn.Module:
"""
Returns the last layer. Used for example to obtain the pre-layer and post-layer dimensions of tensors

"""
raise NotImplementedError()
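To make the contract above concrete, here is a minimal conforming model. The class name and layer sizes are invented for illustration, but the recording pattern is exactly the one the docstring prescribes:

```python
import torch
from torch import nn

from modyn.models.coreset_methods_support import CoresetSupportingModule


class TinyNet(CoresetSupportingModule):
    """Minimal three-layer MLP implementing the coreset-support contract."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.embedding_recorder(x)  # record the penultimate activation
        return self.fc3(x)

    def get_last_layer(self) -> nn.Module:
        return self.fc3


model = TinyNet()
model.embedding_recorder.start_recording()
model(torch.randn(4, 8))
print(model.embedding.shape)  # torch.Size([4, 16])
```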
17 changes: 15 additions & 2 deletions modyn/models/dlrm/dlrm.py
@@ -1,7 +1,8 @@
from typing import Any
from typing import Any, Optional

import numpy as np
import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule, EmbeddingRecorder
from modyn.models.dlrm.nn.factories import create_interaction
from modyn.models.dlrm.nn.parts import DlrmBottom, DlrmTop
from modyn.models.dlrm.utils.install_lib import install_cuda_extensions_if_not_present
@@ -16,7 +17,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
self.model.to(device)


class DlrmModel(nn.Module):
class DlrmModel(CoresetSupportingModule):
# pylint: disable=too-many-instance-attributes
def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
super().__init__()
@@ -124,3 +125,15 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
numerical_input, self.reorder_categorical_input(categorical_input)
)
return self.top_model(from_bottom, bottom_mlp_output).squeeze()

# delegate the embedding handling to the top model
@property
def embedding(self) -> Optional[torch.Tensor]:
return self.top_model.embedding

@property
def embedding_recorder(self) -> EmbeddingRecorder:
return self.top_model.embedding_recorder

def get_last_layer(self) -> nn.Module:
return self.top_model.get_last_layer()
10 changes: 8 additions & 2 deletions modyn/models/dlrm/nn/parts.py
@@ -19,6 +19,7 @@
from typing import Optional, Sequence, Tuple

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
from modyn.models.dlrm.nn.embeddings import Embeddings
from modyn.models.dlrm.nn.factories import create_embeddings, create_mlp
from modyn.models.dlrm.nn.interactions import Interaction
@@ -104,7 +105,7 @@ def forward(self, numerical_input, categorical_inputs) -> Tuple[torch.Tensor, Op
return torch.cat(bottom_output, dim=1), bottom_mlp_output


class DlrmTop(nn.Module):
class DlrmTop(CoresetSupportingModule):
def __init__(self, top_mlp_sizes: Sequence[int], interaction: Interaction, device: str, use_cpp_mlp: bool = False):
super().__init__()

@@ -127,4 +128,9 @@ def forward(self, bottom_output, bottom_mlp_output):
bottom_mlp_output (Tensor): with shape [batch_size, embedding_dim]
"""
interaction_output = self.interaction.interact(bottom_output, bottom_mlp_output)
return self.out(self.mlp(interaction_output))
mlp_output = self.mlp(interaction_output)
mlp_output = self.embedding_recorder(mlp_output)
return self.out(mlp_output)

def get_last_layer(self) -> nn.Module:
return self.out
8 changes: 6 additions & 2 deletions modyn/models/fmownet/fmownet.py
@@ -2,6 +2,7 @@

import torch
import torch.nn.functional as F
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import nn
from torchvision.models import densenet121

@@ -19,7 +20,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool)
self.model.to(device)


class FmowNetModel(nn.Module):
class FmowNetModel(CoresetSupportingModule):
def __init__(self, num_classes: int) -> None:
super().__init__()
self.num_classes = num_classes
@@ -31,5 +32,8 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
out = F.relu(features, inplace=True)
out = F.adaptive_avg_pool2d(out, (1, 1))
out = torch.flatten(out, 1)

out = self.embedding_recorder(out)
return self.classifier(out)

def get_last_layer(self) -> nn.Module:
return self.classifier
38 changes: 36 additions & 2 deletions modyn/models/resnet18/resnet18.py
@@ -1,10 +1,44 @@
from typing import Any

from torchvision import models
import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import Tensor, nn
from torchvision.models.resnet import BasicBlock, ResNet


class ResNet18:
# pylint: disable-next=unused-argument
def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
self.model = models.__dict__["resnet18"](**model_configuration)
self.model = ResNet18Modyn(model_configuration)
self.model.to(device)


# the following class is adapted from
# torchvision https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py


class ResNet18Modyn(ResNet, CoresetSupportingModule):
def __init__(self, model_configuration: dict[str, Any]) -> None:
super().__init__(BasicBlock, [2, 2, 2, 2], **model_configuration)

def _forward_impl(self, x: Tensor) -> Tensor:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

x = self.avgpool(x)
x = torch.flatten(x, 1)
# the following line is the only difference compared to the original implementation
x = self.embedding_recorder(x)
x = self.fc(x)

return x

def get_last_layer(self) -> nn.Module:
return self.fc
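As a sanity check, the adapted model can be exercised as below. This sketch is not part of the PR's tests; the configuration key mirrors torchvision's `num_classes` kwarg, and 512 is ResNet-18's flattened `avgpool` width:

```python
import torch

from modyn.models.resnet18.resnet18 import ResNet18

# illustrative configuration; kwargs are forwarded to torchvision's ResNet
wrapper = ResNet18({"num_classes": 10}, device="cpu", amp=False)
model = wrapper.model  # the ResNet18Modyn instance

model.embedding_recorder.start_recording()
logits = model(torch.randn(2, 3, 224, 224))
assert logits.shape == (2, 10)
assert model.embedding.shape == (2, 512)  # recorded just before self.fc
model.embedding_recorder.end_recording()
```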
8 changes: 6 additions & 2 deletions modyn/models/yearbooknet/yearbooknet.py
@@ -1,6 +1,7 @@
from typing import Any

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import nn


@@ -17,7 +18,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool)
self.model.to(device)


class YearbookNetModel(nn.Module):
class YearbookNetModel(CoresetSupportingModule):
def __init__(self, num_input_channels: int, num_classes: int) -> None:
super().__init__()
self.enc = nn.Sequential(
@@ -37,5 +38,8 @@ def conv_block(self, in_channels: int, out_channels: int) -> nn.Module:
def forward(self, data: torch.Tensor) -> torch.Tensor:
data = self.enc(data)
data = torch.mean(data, dim=(2, 3))

data = self.embedding_recorder(data)
return self.classifier(data)

def get_last_layer(self) -> nn.Module:
return self.classifier
@@ -186,6 +186,7 @@ def trigger(self) -> tuple[int, int, int]:
tuple[int, int, int]: Trigger ID, how many keys are in the trigger, number of overall partitions
"""
# TODO(#276) Unify AbstractSelection Strategy and LocalDatasetWriter

trigger_id = self._next_trigger_id
total_keys_in_trigger = 0

@@ -404,7 +405,8 @@ def get_available_labels(self) -> list[int]:
database.session.query(SelectorStateMetadata.label)
.filter(
SelectorStateMetadata.pipeline_id == self._pipeline_id,
SelectorStateMetadata.seen_in_trigger_id >= self._next_trigger_id - self.tail_triggers
SelectorStateMetadata.seen_in_trigger_id < self._next_trigger_id,
SelectorStateMetadata.seen_in_trigger_id >= self._next_trigger_id - self.tail_triggers - 1
if self.tail_triggers is not None
else True,
)
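Taken together with the integration test above, the updated predicate selects labels from a sliding window of already-completed triggers. The sketch below is one reading of that predicate, not code from this PR, and it assumes `reset_after_trigger` corresponds to `tail_triggers == 0`:

```python
from typing import Optional


def considered_triggers(next_trigger_id: int, tail_triggers: Optional[int]) -> list[int]:
    """Trigger IDs whose labels the updated query would consider (sketch)."""
    if tail_triggers is None:
        # no tail window: every already-completed trigger counts
        return list(range(next_trigger_id))
    # the last tail_triggers + 1 triggers, never the still-open one
    start = max(0, next_trigger_id - tail_triggers - 1)
    return list(range(start, next_trigger_id))


# matches the test: two triggers done, point 99 pending in trigger 2
assert considered_triggers(2, 0) == [1]        # reset: labels [7, 10, 45]
assert considered_triggers(2, None) == [0, 1]  # no reset: all six labels
```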