Introduce CoresetMethodsSupport (#294)
Many DeepCore methods require access to the model's last layer and an
embedding recorder. This PR introduces the CoresetSupportingModule class,
which a model must inherit to use these methods.
francescodeaglio authored Sep 18, 2023
1 parent 013e7af commit 0c1aa65
Showing 17 changed files with 352 additions and 22 deletions.
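
As a quick orientation before the per-file diffs: every model adopts the same pattern, routing the penultimate activation through `self.embedding_recorder` and exposing the final layer via `get_last_layer`. A minimal sketch (the model and layer names are hypothetical; the pattern mirrors the `CoresetSupportingModule` docstring below):

```python
import torch
from torch import nn

from modyn.models.coreset_methods_support import CoresetSupportingModule


class TinyNet(CoresetSupportingModule):  # hypothetical example model
    def __init__(self) -> None:
        super().__init__()
        self.hidden = nn.Linear(4, 3)
        self.classifier = nn.Linear(3, 2)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.hidden(data))
        # record the activation that feeds the last layer
        x = self.embedding_recorder(x)
        return self.classifier(x)

    def get_last_layer(self) -> nn.Module:
        return self.classifier
```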
6 changes: 3 additions & 3 deletions .github/workflows/workflow.yaml
@@ -7,7 +7,7 @@ defaults:

jobs:
flake8:
timeout-minutes: 20
timeout-minutes: 40
runs-on: ubuntu-latest

steps:
@@ -112,7 +112,7 @@ jobs:

# Checks whether the base container works correctly.
dockerized-unittests:
timeout-minutes: 30
timeout-minutes: 60
runs-on: ubuntu-latest
needs:
- flake8
@@ -136,7 +136,7 @@ jobs:
# Tests whether docker-compose up starts all components successfully and integration tests run through
# Only one job to reduce Github CI usage
integrationtests:
timeout-minutes: 30
timeout-minutes: 60
runs-on: ubuntu-latest
needs:
- flake8
27 changes: 24 additions & 3 deletions integrationtests/selector/integrationtest_selector.py
@@ -910,8 +910,8 @@ def test_get_available_labels(reset_after_trigger: bool):
GetAvailableLabelsRequest(pipeline_id=pipeline_id)
).available_labels

assert len(available_labels) == 2
assert 0 in available_labels and 1 in available_labels
# here we expect to have 0 labels since it's before the first trigger
assert len(available_labels) == 0

selector.inform_data_and_trigger(
DataInformRequest(
@@ -922,22 +922,43 @@
)
)

selector.inform_data(
available_labels = selector.get_available_labels(
GetAvailableLabelsRequest(pipeline_id=pipeline_id)
).available_labels

# we want all the labels belonging to the first trigger
assert len(available_labels) == 3
assert sorted(available_labels) == [0, 1, 189]

selector.inform_data_and_trigger(
DataInformRequest(
pipeline_id=pipeline_id,
keys=[4, 5, 6],
timestamps=[4, 5, 6],
labels=[10, 7, 45],
)
)

# this label (99) should not appear in the available labels since it belongs to a future trigger.
selector.inform_data(
DataInformRequest(
pipeline_id=pipeline_id,
keys=[99],
timestamps=[7],
labels=[99],
)
)

available_labels = selector.get_available_labels(
GetAvailableLabelsRequest(pipeline_id=pipeline_id)
).available_labels

if reset_after_trigger:
# only the last trigger must be considered, but not point99
assert len(available_labels) == 3
assert sorted(available_labels) == [7, 10, 45]
else:
# every past point must be considered. Only point99 is excluded.
assert len(available_labels) == 6
assert sorted(available_labels) == [0, 1, 7, 10, 45, 189]

7 changes: 6 additions & 1 deletion modyn/models/README.md
@@ -5,4 +5,9 @@ The user can define models here. The model definition should take as a parameter
# Wild Time models
The code for the models used for WildTime is taken from the official [repository](https://github.com/huaxiuyao/Wild-Time).
The original version is linked in each class.
You can find [here](https://raw.githubusercontent.com/huaxiuyao/Wild-Time/main/LICENSE) a copy of the MIT license
You can find [here](https://raw.githubusercontent.com/huaxiuyao/Wild-Time/main/LICENSE) a copy of the MIT license

# Embedding Recorder
Many coreset methods are adapted from the [DeepCore](https://github.com/PatrickZH/DeepCore/) library. To use them, the models must keep track of the embeddings (activations of the penultimate layer). This is
done using the `EmbeddingRecorder` class, which is adapted from the aforementioned project.
You can find a copy of their MIT license [here](https://raw.githubusercontent.com/PatrickZH/DeepCore/main/LICENSE.md).
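
For illustration, a hedged sketch of how recording is toggled around a forward pass (`model` stands for any `CoresetSupportingModule` subclass and `batch` for an input tensor; both names are assumptions, not part of this diff):

```python
model.embedding_recorder.start_recording()
output = model(batch)        # the forward pass stores the penultimate activation
embedding = model.embedding  # recorded tensor of shape [batch_size, embedding_dim]
model.embedding_recorder.end_recording()
```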
7 changes: 6 additions & 1 deletion modyn/models/articlenet/articlenet.py
@@ -1,6 +1,7 @@
from typing import Any

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import nn
from transformers import DistilBertModel

@@ -43,12 +44,16 @@ def __call__(self, data: torch.Tensor) -> torch.Tensor:
return pooled_output


class ArticleNetwork(nn.Module):
class ArticleNetwork(CoresetSupportingModule):
def __init__(self, num_classes: int) -> None:
super().__init__()
self.featurizer = DistilBertFeaturizer.from_pretrained("distilbert-base-uncased")
self.classifier = nn.Linear(self.featurizer.d_out, num_classes)

def forward(self, data: torch.Tensor) -> torch.Tensor:
embedding = self.featurizer(data)
embedding = self.embedding_recorder(embedding)
return self.classifier(embedding)

def get_last_layer(self) -> nn.Module:
return self.classifier
60 changes: 60 additions & 0 deletions modyn/models/coreset_methods_support.py
@@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import nn


# acknowledgment: github.com/PatrickZH/DeepCore/
class EmbeddingRecorder(nn.Module):
def __init__(self, record_embedding: bool = False):
super().__init__()
self.record_embedding = record_embedding
self.embedding: Optional[torch.Tensor] = None

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
if self.record_embedding:
self.embedding = tensor
return tensor

def start_recording(self) -> None:
self.record_embedding = True

def end_recording(self) -> None:
self.record_embedding = False
self.embedding = None


class CoresetSupportingModule(nn.Module, ABC):
"""
This class is used to support some Coreset Methods.
Embeddings, here defined as the activation before the last layer, are often used to estimate the importance of
a point. To implement this class correctly, it is necessary to
- implement the get_last_layer method
- modify the forward pass so that the last layer embedding is recorded. For example, in a simple network like
x = self.fc1(input)
x = self.fc2(x)
output = self.fc3(x)
it must be modified as follows
x = self.fc1(input)
x = self.fc2(x)
x = self.embedding_recorder(x)
output = self.fc3(x)
"""

def __init__(self, record_embedding: bool = False) -> None:
super().__init__()
self.embedding_recorder = EmbeddingRecorder(record_embedding)

@property
def embedding(self) -> Optional[torch.Tensor]:
assert self.embedding_recorder is not None
return self.embedding_recorder.embedding

@abstractmethod
def get_last_layer(self) -> nn.Module:
"""
Returns the last layer. Used for example to obtain the pre-layer and post-layer dimensions of tensors
"""
raise NotImplementedError()
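
As a usage sketch (not part of this diff), a coreset method could combine the recorder with `get_last_layer` to score samples DeepCore-style. Here `model`, `inputs`, and `targets` are assumptions; the closed-form gradient norm holds for a linear last layer trained with cross-entropy:

```python
import torch
import torch.nn.functional as F

model.embedding_recorder.start_recording()
with torch.no_grad():
    logits = model(inputs)       # forward pass records the embedding
embedding = model.embedding      # activations before the last layer
model.embedding_recorder.end_recording()

# The per-sample gradient of a linear last layer under cross-entropy is the
# outer product (softmax - one_hot) x embedding; its Frobenius norm
# factorizes into the product of the two vector norms.
probs = F.softmax(logits, dim=1)
one_hot = F.one_hot(targets, num_classes=logits.shape[1]).float()
scores = (probs - one_hot).norm(dim=1) * embedding.norm(dim=1)
```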
17 changes: 15 additions & 2 deletions modyn/models/dlrm/dlrm.py
@@ -1,7 +1,8 @@
from typing import Any
from typing import Any, Optional

import numpy as np
import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule, EmbeddingRecorder
from modyn.models.dlrm.nn.factories import create_interaction
from modyn.models.dlrm.nn.parts import DlrmBottom, DlrmTop
from modyn.models.dlrm.utils.install_lib import install_cuda_extensions_if_not_present
@@ -16,7 +17,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool)
self.model.to(device)


class DlrmModel(nn.Module):
class DlrmModel(CoresetSupportingModule):
# pylint: disable=too-many-instance-attributes
def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
super().__init__()
@@ -124,3 +125,15 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
numerical_input, self.reorder_categorical_input(categorical_input)
)
return self.top_model(from_bottom, bottom_mlp_output).squeeze()

# delegate the embedding handling to the top model
@property
def embedding(self) -> Optional[torch.Tensor]:
return self.top_model.embedding

@property
def embedding_recorder(self) -> EmbeddingRecorder:
return self.top_model.embedding_recorder

def get_last_layer(self) -> nn.Module:
return self.top_model.get_last_layer()
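
Design note (a reading of the diff, not stated in it): because DLRM's last layer lives inside `DlrmTop`, `DlrmModel` records nothing itself; it forwards `embedding` and `embedding_recorder` to the submodule, so the activation is recorded at exactly one point in the forward pass.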
10 changes: 8 additions & 2 deletions modyn/models/dlrm/nn/parts.py
@@ -19,6 +19,7 @@
from typing import Optional, Sequence, Tuple

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
from modyn.models.dlrm.nn.embeddings import Embeddings
from modyn.models.dlrm.nn.factories import create_embeddings, create_mlp
from modyn.models.dlrm.nn.interactions import Interaction
@@ -104,7 +105,7 @@ def forward(self, numerical_input, categorical_inputs) -> Tuple[torch.Tensor, Op
return torch.cat(bottom_output, dim=1), bottom_mlp_output


class DlrmTop(nn.Module):
class DlrmTop(CoresetSupportingModule):
def __init__(self, top_mlp_sizes: Sequence[int], interaction: Interaction, device: str, use_cpp_mlp: bool = False):
super().__init__()

@@ -127,4 +128,9 @@ def forward(self, bottom_output, bottom_mlp_output):
bottom_mlp_output (Tensor): with shape [batch_size, embedding_dim]
"""
interaction_output = self.interaction.interact(bottom_output, bottom_mlp_output)
return self.out(self.mlp(interaction_output))
mlp_output = self.mlp(interaction_output)
mlp_output = self.embedding_recorder(mlp_output)
return self.out(mlp_output)

def get_last_layer(self) -> nn.Module:
return self.out
8 changes: 6 additions & 2 deletions modyn/models/fmownet/fmownet.py
@@ -2,6 +2,7 @@

import torch
import torch.nn.functional as F
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import nn
from torchvision.models import densenet121

@@ -19,7 +20,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool)
self.model.to(device)


class FmowNetModel(nn.Module):
class FmowNetModel(CoresetSupportingModule):
def __init__(self, num_classes: int) -> None:
super().__init__()
self.num_classes = num_classes
@@ -31,5 +32,8 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
out = F.relu(features, inplace=True)
out = F.adaptive_avg_pool2d(out, (1, 1))
out = torch.flatten(out, 1)

out = self.embedding_recorder(out)
return self.classifier(out)

def get_last_layer(self) -> nn.Module:
return self.classifier
38 changes: 36 additions & 2 deletions modyn/models/resnet18/resnet18.py
@@ -1,10 +1,44 @@
from typing import Any

from torchvision import models
import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import Tensor, nn
from torchvision.models.resnet import BasicBlock, ResNet


class ResNet18:
# pylint: disable-next=unused-argument
def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
self.model = models.__dict__["resnet18"](**model_configuration)
self.model = ResNet18Modyn(model_configuration)
self.model.to(device)


# the following class is adapted from
# torchvision https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py


class ResNet18Modyn(ResNet, CoresetSupportingModule):
def __init__(self, model_configuration: dict[str, Any]) -> None:
super().__init__(BasicBlock, [2, 2, 2, 2], **model_configuration)

def _forward_impl(self, x: Tensor) -> Tensor:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

x = self.avgpool(x)
x = torch.flatten(x, 1)
# the following line is the only difference compared to the original implementation
x = self.embedding_recorder(x)
x = self.fc(x)

return x

def get_last_layer(self) -> nn.Module:
return self.fc
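
Design note (a reading of the class hierarchy, not stated in the diff): `ResNet18Modyn` inherits from both torchvision's `ResNet` and `CoresetSupportingModule`. Since both derive from `nn.Module` and call `super().__init__()` cooperatively, the single `super().__init__(BasicBlock, ...)` call walks the MRO, initializing the ResNet layers and also attaching the `embedding_recorder`.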
8 changes: 6 additions & 2 deletions modyn/models/yearbooknet/yearbooknet.py
@@ -1,6 +1,7 @@
from typing import Any

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
from torch import nn


@@ -17,7 +18,7 @@ def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool)
self.model.to(device)


class YearbookNetModel(nn.Module):
class YearbookNetModel(CoresetSupportingModule):
def __init__(self, num_input_channels: int, num_classes: int) -> None:
super().__init__()
self.enc = nn.Sequential(
@@ -37,5 +38,8 @@ def conv_block(self, in_channels: int, out_channels: int) -> nn.Module:
def forward(self, data: torch.Tensor) -> torch.Tensor:
data = self.enc(data)
data = torch.mean(data, dim=(2, 3))

data = self.embedding_recorder(data)
return self.classifier(data)

def get_last_layer(self) -> nn.Module:
return self.classifier
@@ -186,6 +186,7 @@ def trigger(self) -> tuple[int, int, int]:
tuple[int, int, int]: Trigger ID, how many keys are in the trigger, number of overall partitions
"""
# TODO(#276) Unify AbstractSelection Strategy and LocalDatasetWriter

trigger_id = self._next_trigger_id
total_keys_in_trigger = 0

@@ -404,7 +405,8 @@ def get_available_labels(self) -> list[int]:
database.session.query(SelectorStateMetadata.label)
.filter(
SelectorStateMetadata.pipeline_id == self._pipeline_id,
SelectorStateMetadata.seen_in_trigger_id >= self._next_trigger_id - self.tail_triggers
SelectorStateMetadata.seen_in_trigger_id < self._next_trigger_id,
SelectorStateMetadata.seen_in_trigger_id >= self._next_trigger_id - self.tail_triggers - 1
if self.tail_triggers is not None
else True,
)
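
To make the new bounds concrete, a hedged worked example matching the integration test above (assuming `reset_after_trigger` corresponds to `tail_triggers = 0`):

```python
next_trigger_id = 2  # triggers 0 and 1 have completed
tail_triggers = 0    # assumed value when reset_after_trigger is set
lower = next_trigger_id - tail_triggers - 1  # = 1
# kept: lower <= seen_in_trigger_id < next_trigger_id, i.e. only trigger 1
# point99 arrived after the second trigger (seen_in_trigger_id == 2), so the
# new upper bound excludes it
```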