-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce CoresetMethodsSupport (#294)
Many DeepCore methods require obtaining the last layer and using an embedding recorder. This PR introduces the CoresetSupportingModule class that must be inherited to use these methods.
- Loading branch information
1 parent
013e7af
commit 0c1aa65
Showing
17 changed files
with
352 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Optional | ||
|
||
import torch | ||
from torch import nn | ||
|
||
|
||
# acknowledgment: github.com/PatrickZH/DeepCore/ | ||
class EmbeddingRecorder(nn.Module): | ||
def __init__(self, record_embedding: bool = False): | ||
super().__init__() | ||
self.record_embedding = record_embedding | ||
self.embedding: Optional[torch.Tensor] = None | ||
|
||
def forward(self, tensor: torch.Tensor) -> torch.Tensor: | ||
if self.record_embedding: | ||
self.embedding = tensor | ||
return tensor | ||
|
||
def start_recording(self) -> None: | ||
self.record_embedding = True | ||
|
||
def end_recording(self) -> None: | ||
self.record_embedding = False | ||
self.embedding = None | ||
|
||
|
||
class CoresetSupportingModule(nn.Module, ABC): | ||
""" | ||
This class is used to support some Coreset Methods. | ||
Embeddings, here defined as the activation before the last layer, are often used to estimate the importance of | ||
a point. To implement this class correctly, it is necessary to | ||
- implement the get_last_layer method | ||
- modify the forward pass so that the last layer embedding is recorded. For example, in a simple network like | ||
x = self.fc1(input) | ||
x = self.fc2(x) | ||
output = self.fc3(x) | ||
it must be modified as follows | ||
x = self.fc1(input) | ||
x = self.fc2(x) | ||
x = self.embedding_recorder(x) | ||
output = self.fc3(x) | ||
""" | ||
|
||
def __init__(self, record_embedding: bool = False) -> None: | ||
super().__init__() | ||
self.embedding_recorder = EmbeddingRecorder(record_embedding) | ||
|
||
@property | ||
def embedding(self) -> Optional[torch.Tensor]: | ||
assert self.embedding_recorder is not None | ||
return self.embedding_recorder.embedding | ||
|
||
@abstractmethod | ||
def get_last_layer(self) -> nn.Module: | ||
""" | ||
Returns the last layer. Used for example to obtain the pre-layer and post-layer dimensions of tensors | ||
""" | ||
raise NotImplementedError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,44 @@ | ||
from typing import Any | ||
|
||
from torchvision import models | ||
import torch | ||
from modyn.models.coreset_methods_support import CoresetSupportingModule | ||
from torch import Tensor, nn | ||
from torchvision.models.resnet import BasicBlock, ResNet | ||
|
||
|
||
class ResNet18: | ||
# pylint: disable-next=unused-argument | ||
def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None: | ||
self.model = models.__dict__["resnet18"](**model_configuration) | ||
self.model = ResNet18Modyn(model_configuration) | ||
self.model.to(device) | ||
|
||
|
||
# the following class is adapted from | ||
# torchvision https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py | ||
|
||
|
||
class ResNet18Modyn(ResNet, CoresetSupportingModule): | ||
def __init__(self, model_configuration: dict[str, Any]) -> None: | ||
super().__init__(BasicBlock, [2, 2, 2, 2], **model_configuration) | ||
|
||
def _forward_impl(self, x: Tensor) -> Tensor: | ||
x = self.conv1(x) | ||
x = self.bn1(x) | ||
x = self.relu(x) | ||
x = self.maxpool(x) | ||
|
||
x = self.layer1(x) | ||
x = self.layer2(x) | ||
x = self.layer3(x) | ||
x = self.layer4(x) | ||
|
||
x = self.avgpool(x) | ||
x = torch.flatten(x, 1) | ||
# the following line is the only difference compared to the original implementation | ||
x = self.embedding_recorder(x) | ||
x = self.fc(x) | ||
|
||
return x | ||
|
||
def get_last_layer(self) -> nn.Module: | ||
return self.fc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.