-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce policies from DeepCore (#304)
With this PR, we introduce several new downsampling methods: `Craig`, `GradMatch`, `Submodular` (FacilityLocation, GraphCut and LogDeterminant), `KcenterGreedy` and `Uncertainty` (Margin, Entropy, and LeastConfidence). These methods are adapted from [DEEPCORE](https://github.com/PatrickZH/DeepCore). The following functionalities are implemented: - `CoresetSupportingModule` support in both BTS and STB mode. Embeddings are registered and provided in the `inform_samples` method if required. - Addition of the `device` parameter to the constructor of the downsampling methods. For example, kcenter uses a different implementation depending on whether it runs on CPU or GPU. - Implementation of the techniques mentioned above. Many of them share common behaviour, which is abstracted in the `AbstractMatrixDownsamplingStrategy` class. - Fixes for some bugs found by running the experiments (typically moving weights to the correct device). All methods were tested by comparing the results with those obtained with DeepCore in various controlled experiments. The available tests serve this purpose.
- Loading branch information
1 parent
73f9c0a
commit cc139e0
Showing
46 changed files
with
2,782 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ dependencies: | |
- numpy | ||
- pandas | ||
- tensorboard | ||
- scipy | ||
- pyftpdlib | ||
- types-protobuf | ||
- types-psycopg2 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 5 additions & 0 deletions
5
modyn/selector/internal/selector_strategies/downsampling_strategies/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
18 changes: 18 additions & 0 deletions
18
...ector/internal/selector_strategies/downsampling_strategies/craig_downsampling_strategy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from modyn.selector.internal.selector_strategies.downsampling_strategies import AbstractDownsamplingStrategy | ||
from modyn.utils import DownsamplingMode | ||
|
||
|
||
class CraigDownsamplingStrategy(AbstractDownsamplingStrategy):
    """Selector-side configuration for the Craig downsampler.

    Names ``RemoteCraigDownsamplingStrategy`` as the remote strategy and
    forwards the ``selection_batch``, ``balance`` and ``greedy`` options.
    """

    def __init__(self, downsampling_config: dict, maximum_keys_in_memory: int):
        super().__init__(downsampling_config, maximum_keys_in_memory)
        self.remote_downsampling_strategy_name = "RemoteCraigDownsamplingStrategy"

    def _build_downsampling_params(self) -> dict:
        """Extend the base parameters with the Craig-specific options.

        Returns:
            The parameter dict built by the parent class plus
            ``selection_batch`` (default 64), ``balance`` (default False)
            and ``greedy`` (default "NaiveGreedy").

        Raises:
            ValueError: if ``balance`` is truthy while the strategy runs in
                ``DownsamplingMode.BATCH_THEN_SAMPLE``.
        """
        params = super()._build_downsampling_params()
        user_options = self.downsampling_config

        params["selection_batch"] = user_options.get("selection_batch", 64)
        params["balance"] = user_options.get("balance", False)
        params["greedy"] = user_options.get("greedy", "NaiveGreedy")

        batch_then_sample = self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE
        if params["balance"] and batch_then_sample:
            raise ValueError("Balanced sampling (balance=True) can be used only in Sample then Batch mode.")

        return params
16 changes: 16 additions & 0 deletions
16
...r/internal/selector_strategies/downsampling_strategies/gradmatch_downsampling_strategy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from modyn.selector.internal.selector_strategies.downsampling_strategies import AbstractDownsamplingStrategy | ||
from modyn.utils import DownsamplingMode | ||
|
||
|
||
class GradMatchDownsamplingStrategy(AbstractDownsamplingStrategy):
    """Selector-side configuration for the GradMatch downsampler.

    Names ``RemoteGradMatchDownsamplingStrategy`` as the remote strategy and
    forwards the ``balance`` option.
    """

    def __init__(self, downsampling_config: dict, maximum_keys_in_memory: int):
        super().__init__(downsampling_config, maximum_keys_in_memory)
        self.remote_downsampling_strategy_name = "RemoteGradMatchDownsamplingStrategy"

    def _build_downsampling_params(self) -> dict:
        """Extend the base parameters with the ``balance`` flag.

        Returns:
            The parameter dict built by the parent class plus ``balance``
            (default False).

        Raises:
            ValueError: if ``balance`` is truthy while the strategy runs in
                ``DownsamplingMode.BATCH_THEN_SAMPLE``.
        """
        params = super()._build_downsampling_params()
        params["balance"] = self.downsampling_config.get("balance", False)

        batch_then_sample = self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE
        if params["balance"] and batch_then_sample:
            raise ValueError("Balanced sampling (balance=True) can be used only in Sample then Batch mode.")

        return params
16 changes: 16 additions & 0 deletions
16
...ternal/selector_strategies/downsampling_strategies/kcentergreedy_downsampling_strategy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from modyn.selector.internal.selector_strategies.downsampling_strategies import AbstractDownsamplingStrategy | ||
from modyn.utils import DownsamplingMode | ||
|
||
|
||
class KcenterGreedyDownsamplingStrategy(AbstractDownsamplingStrategy):
    """Selector-side configuration for the k-center greedy downsampler.

    Names ``RemoteKcenterGreedyDownsamplingStrategy`` as the remote strategy
    and forwards the ``balance`` option.
    """

    def __init__(self, downsampling_config: dict, maximum_keys_in_memory: int):
        super().__init__(downsampling_config, maximum_keys_in_memory)
        self.remote_downsampling_strategy_name = "RemoteKcenterGreedyDownsamplingStrategy"

    def _build_downsampling_params(self) -> dict:
        """Extend the base parameters with the ``balance`` flag.

        Returns:
            The parameter dict built by the parent class plus ``balance``
            (default False).

        Raises:
            ValueError: if ``balance`` is truthy while the strategy runs in
                ``DownsamplingMode.BATCH_THEN_SAMPLE``.
        """
        params = super()._build_downsampling_params()
        params["balance"] = self.downsampling_config.get("balance", False)

        batch_then_sample = self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE
        if params["balance"] and batch_then_sample:
            raise ValueError("Balanced sampling (balance=True) can be used only in Sample then Batch mode.")

        return params
33 changes: 33 additions & 0 deletions
33
.../internal/selector_strategies/downsampling_strategies/submodular_downsampling_strategy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from modyn.selector.internal.selector_strategies.downsampling_strategies import AbstractDownsamplingStrategy | ||
from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.submodular_function import ( | ||
SUBMODULAR_FUNCTIONS, | ||
) | ||
from modyn.utils import DownsamplingMode | ||
|
||
|
||
class SubmodularDownsamplingStrategy(AbstractDownsamplingStrategy):
    """Selector-side configuration for the submodular downsampler.

    Names ``RemoteSubmodularDownsamplingStrategy`` as the remote strategy and
    forwards the submodular function, the optional optimizer, the selection
    batch size and the ``balance`` flag.
    """

    def __init__(self, downsampling_config: dict, maximum_keys_in_memory: int):
        super().__init__(downsampling_config, maximum_keys_in_memory)
        self.remote_downsampling_strategy_name = "RemoteSubmodularDownsamplingStrategy"

    def _build_downsampling_params(self) -> dict:
        """Extend the base parameters with the submodular-specific options.

        Returns:
            The parameter dict built by the parent class plus
            ``submodular_function`` (mandatory), ``submodular_optimizer``
            (only when configured), ``selection_batch`` (default 64) and
            ``balance`` (default False).

        Raises:
            ValueError: if ``submodular_function`` is missing from the
                configuration, or if ``balance`` is truthy while the strategy
                runs in ``DownsamplingMode.BATCH_THEN_SAMPLE``.
        """
        params = super()._build_downsampling_params()
        user_options = self.downsampling_config

        # The submodular function has no default: the pipeline must name one.
        if "submodular_function" not in user_options:
            raise ValueError(
                "Please specify the submodular function used to select the datapoints. "
                f"Available functions: {SUBMODULAR_FUNCTIONS}, param submodular_function"
            )
        params["submodular_function"] = user_options["submodular_function"]

        # The optimizer is forwarded only when explicitly configured.
        if "submodular_optimizer" in user_options:
            params["submodular_optimizer"] = user_options["submodular_optimizer"]

        params["selection_batch"] = user_options.get("selection_batch", 64)
        params["balance"] = user_options.get("balance", False)

        batch_then_sample = self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE
        if params["balance"] and batch_then_sample:
            raise ValueError("Balanced sampling (balance=True) can be used only in Sample then Batch mode.")

        return params
26 changes: 26 additions & 0 deletions
26
...internal/selector_strategies/downsampling_strategies/uncertainty_downsampling_strategy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from modyn.selector.internal.selector_strategies.downsampling_strategies import AbstractDownsamplingStrategy | ||
from modyn.utils import DownsamplingMode | ||
|
||
|
||
class UncertaintyDownsamplingStrategy(AbstractDownsamplingStrategy):
    """Selector-side configuration for the uncertainty-based downsampler.

    Names ``RemoteUncertaintyDownsamplingStrategy`` as the remote strategy and
    forwards the scoring metric and the ``balance`` flag.
    """

    def __init__(self, downsampling_config: dict, maximum_keys_in_memory: int):
        super().__init__(downsampling_config, maximum_keys_in_memory)

        self.remote_downsampling_strategy_name = "RemoteUncertaintyDownsamplingStrategy"

    def _build_downsampling_params(self) -> dict:
        """Extend the base parameters with the uncertainty-specific options.

        Returns:
            The parameter dict built by the parent class plus ``score_metric``
            (mandatory) and ``balance`` (default False).

        Raises:
            ValueError: if ``score_metric`` is missing from the configuration,
                or if ``balance`` is truthy while the strategy runs in
                ``DownsamplingMode.BATCH_THEN_SAMPLE``.
        """
        config = super()._build_downsampling_params()

        if "score_metric" not in self.downsampling_config:
            # Adjacent literals are concatenated verbatim, so every fragment
            # must carry its own trailing separator.
            raise ValueError(
                "Please specify the metric used to score uncertainty for the datapoints. "
                "Available metrics : LeastConfidence, Entropy, Margin. "
                "Use the pipeline parameter score_metric"
            )
        config["score_metric"] = self.downsampling_config["score_metric"]

        config["balance"] = self.downsampling_config.get("balance", False)
        if config["balance"] and self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
            raise ValueError("Balanced sampling (balance=True) can be used only in Sample then Batch mode.")

        return config
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
27 changes: 27 additions & 0 deletions
27
...ts/trainer_server/internal/trainer/remote_downsamplers/deepcore_comparison_tests_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import numpy as np | ||
import torch | ||
from modyn.models.coreset_methods_support import CoresetSupportingModule | ||
from torch import nn | ||
|
||
|
||
class DummyModel(CoresetSupportingModule):
    """Minimal two-layer network (1 -> 10 -> 1) used as a test fixture.

    The ReLU-activated hidden representation is passed through
    ``self.embedding_recorder`` before reaching the output layer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_layer = nn.Linear(in_features=1, out_features=10)
        self.output_layer = nn.Linear(in_features=10, out_features=1)

    def forward(self, input_tensor):
        """Run the hidden layer, route the activation through the recorder, then the output layer."""
        activation = torch.relu(self.hidden_layer(input_tensor))
        recorded = self.embedding_recorder(activation)
        return self.output_layer(recorded)

    def get_last_layer(self):
        """Return the final linear layer of the model."""
        return self.output_layer
|
||
|
||
def assert_close_matrices(matrix1, matrix2):
    """Assert that two matrices have the same shape and element-wise close values.

    Args:
        matrix1: first matrix, a sized sequence of sized rows of numbers.
        matrix2: second matrix, same layout as ``matrix1``.

    Raises:
        AssertionError: if the number of rows differs, if any pair of rows has
            a different length, or if any pair of elements differs by more
            than a relative tolerance of 1e-2 (plus numpy's default absolute
            tolerance).
    """
    # zip() stops at the shorter input, so without this check a missing or
    # extra row would go unnoticed.
    assert len(matrix1) == len(matrix2)
    for row1, row2 in zip(matrix1, matrix2):
        assert len(row1) == len(row2)
        for el1, el2 in zip(row1, row2):
            assert np.isclose(el1, el2, rtol=1e-2)
123 changes: 123 additions & 0 deletions
123
...server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# pylint: disable=abstract-class-instantiated,unused-argument | ||
from unittest.mock import patch | ||
|
||
import numpy as np | ||
import torch | ||
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_matrix_downsampling_strategy import ( | ||
AbstractMatrixDownsamplingStrategy, | ||
MatrixContent, | ||
) | ||
|
||
|
||
def get_sampler_config(balance=False):
    """Build the positional constructor arguments used by the tests in this file.

    Args:
        balance: value stored under the ``"balance"`` key of the selector params.

    Returns:
        A 6-tuple ``(0, 0, 0, params_from_selector, per_sample_loss_fct, "cpu")``
        where ``params_from_selector`` requests a downsampling ratio of 50 with
        ``sample_then_batch`` disabled, and the loss is an unreduced
        cross-entropy.
    """
    selector_params = {
        "downsampling_ratio": 50,
        "sample_then_batch": False,
        "args": {},
        "balance": balance,
    }
    unreduced_loss = torch.nn.CrossEntropyLoss(reduction="none")
    return 0, 0, 0, selector_params, unreduced_loss, "cpu"
|
||
|
||
@patch.multiple(AbstractMatrixDownsamplingStrategy, __abstractmethods__=set())
def test_init():
    """A freshly built strategy needs a coreset module, holds no elements and has no matrix content."""
    strategy = AbstractMatrixDownsamplingStrategy(*get_sampler_config())

    assert strategy.matrix_content is None
    assert not strategy.matrix_elements
    assert strategy.requires_coreset_supporting_module
|
||
|
||
@patch.multiple(AbstractMatrixDownsamplingStrategy, __abstractmethods__=set())
def test_collect_embeddings():
    """Embeddings given to inform_samples pile up batch by batch, together with their sample ids."""
    strategy = AbstractMatrixDownsamplingStrategy(*get_sampler_config())
    strategy.matrix_content = MatrixContent.EMBEDDINGS

    assert strategy.requires_coreset_supporting_module
    assert not strategy.matrix_elements  # truthiness check: the list starts out empty

    # Two batches of 4 and 3 samples with 5-dimensional embeddings.
    first_embedding = torch.randn((4, 5))
    second_embedding = torch.randn((3, 5))
    strategy.inform_samples([1, 2, 3, 4], None, None, first_embedding)
    strategy.inform_samples([21, 31, 41], None, None, second_embedding)

    expected_batches = [first_embedding, second_embedding]
    stacked = np.concatenate(strategy.matrix_elements)
    assert stacked.shape == (7, 5)
    assert all(torch.equal(stored, expected) for stored, expected in zip(strategy.matrix_elements, expected_batches))
    assert strategy.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41]

    # A third batch of 23 samples is appended to what is already stored.
    third_embedding = torch.randn((23, 5))
    strategy.inform_samples(list(range(1000, 1023)), None, None, third_embedding)

    expected_batches = [first_embedding, second_embedding, third_embedding]
    stacked = np.concatenate(strategy.matrix_elements)
    assert stacked.shape == (30, 5)
    assert all(torch.equal(stored, expected) for stored, expected in zip(strategy.matrix_elements, expected_batches))
    assert strategy.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41] + list(range(1000, 1023))
|
||
|
||
@patch.multiple(AbstractMatrixDownsamplingStrategy, __abstractmethods__=set())
@patch.object(
    AbstractMatrixDownsamplingStrategy,
    "_select_indexes_from_matrix",
    return_value=([0, 2], torch.Tensor([1.0, 3.0])),
)
def test_collect_embedding_balance(test_amds):
    """With balance enabled, ending a label empties the buffer and records the selected sample ids.

    ``_select_indexes_from_matrix`` is mocked to always pick indexes 0 and 2,
    so each label contributes the first and third sample id it received.
    """
    strategy = AbstractMatrixDownsamplingStrategy(*get_sampler_config(True))
    strategy.matrix_content = MatrixContent.EMBEDDINGS

    assert strategy.requires_coreset_supporting_module
    assert strategy.requires_data_label_by_label
    assert not strategy.matrix_elements  # truthiness check: the list starts out empty

    # First label: two batches of 4 and 3 samples.
    first_embedding = torch.randn((4, 5))
    second_embedding = torch.randn((3, 5))
    strategy.inform_samples([1, 2, 3, 4], None, None, first_embedding)
    strategy.inform_samples([21, 31, 41], None, None, second_embedding)

    first_label_batches = [first_embedding, second_embedding]
    assert np.concatenate(strategy.matrix_elements).shape == (7, 5)
    assert all(
        torch.equal(stored, expected) for stored, expected in zip(strategy.matrix_elements, first_label_batches)
    )
    assert strategy.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41]

    strategy.inform_end_of_current_label()

    # Second label: the buffer must be empty before the new batch arrives.
    third_embedding = torch.randn((23, 5))
    assert len(strategy.matrix_elements) == 0
    strategy.inform_samples(list(range(1000, 1023)), None, None, third_embedding)

    assert np.concatenate(strategy.matrix_elements).shape == (23, 5)
    assert all(torch.equal(stored, expected) for stored, expected in zip(strategy.matrix_elements, [third_embedding]))
    assert strategy.index_sampleid_map == list(range(1000, 1023))

    # Selections accumulate across labels.
    assert strategy.already_selected_samples == [1, 3]
    strategy.inform_end_of_current_label()
    assert strategy.already_selected_samples == [1, 3, 1000, 1002]
|
||
|
||
@patch.multiple(AbstractMatrixDownsamplingStrategy, __abstractmethods__=set())
def test_collect_gradients():
    """In GRADIENTS mode each inform_samples call stores one matrix element per batch."""
    strategy = AbstractMatrixDownsamplingStrategy(*get_sampler_config())
    strategy.matrix_content = MatrixContent.GRADIENTS

    # First batch: 4 samples, 2 output classes, 5-dimensional embeddings.
    first_output = torch.randn((4, 2)).requires_grad_()
    first_target = torch.tensor([1, 1, 1, 0])
    first_embedding = torch.randn((4, 5))
    strategy.inform_samples([1, 2, 3, 4], first_output, first_target, first_embedding)

    # Second batch: 3 samples with the same dimensions.
    second_output = torch.randn((3, 2)).requires_grad_()
    second_target = torch.tensor([0, 1, 0])
    second_embedding = torch.randn((3, 5))
    strategy.inform_samples([21, 31, 41], second_output, second_target, second_embedding)

    assert len(strategy.matrix_elements) == 2

    # Expected shape is (rows, cols):
    #   rows = 7 (4 samples in the first batch plus 3 in the second)
    #   cols = 5 * 2 + 2, where 5 is the input dimension of the last layer and 2 its output dimension
    stacked = np.concatenate(strategy.matrix_elements)
    assert stacked.shape == (7, 12)

    assert strategy.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.