diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py
index cfee6d511..901aa517b 100644
--- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py
+++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py
@@ -41,8 +41,7 @@ def test_sample_shape_ce(dummy_system_config: ModynConfig):
     assert set(downsampled_indexes) <= set(range(8))
 
 
-@pytest.mark.parametrize("squeeze_dim", [True, False])
-def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim):
+def test_sample_shape_binary(dummy_system_config: ModynConfig):
     model = torch.nn.Linear(10, 1)
     downsampling_ratio = 50
     per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")
@@ -53,11 +52,8 @@ def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim):
     )
     with torch.inference_mode(mode=(not sampler.requires_grad)):
         data = torch.randn(8, 10)
-        forward_outputs = model(data)
-        target = torch.randint(2, size=(8,), dtype=torch.float32).unsqueeze(1)
-        if squeeze_dim:
-            target = target.squeeze(1)
-            forward_outputs = forward_outputs.squeeze(1)
+        forward_outputs = model(data).squeeze(1)
+        target = torch.randint(2, size=(8,), dtype=torch.float32)
 
         ids = list(range(8))
         sampler.inform_samples(ids, data, forward_outputs, target)
diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py
index 52bd9e124..8b4174b93 100644
--- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py
+++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py
@@ -36,8 +36,7 @@ def test_sample_shape(dummy_system_config: ModynConfig):
     assert len(indexes) == 4
 
 
-@pytest.mark.parametrize("squeeze_dim", [True, False])
-def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim):
+def test_sample_shape_binary(dummy_system_config: ModynConfig):
     model = torch.nn.Linear(10, 1)
     downsampling_ratio = 50
     per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")
@@ -48,11 +47,8 @@ def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim):
     )
     with torch.inference_mode(mode=(not sampler.requires_grad)):
         data = torch.randn(8, 10)
-        forward_outputs = model(data)
-        target = torch.randint(2, size=(8,), dtype=torch.float32).unsqueeze(1)
-        if squeeze_dim:
-            target = target.squeeze(1)
-            forward_outputs = forward_outputs.squeeze(1)
+        forward_outputs = model(data).squeeze(1)
+        target = torch.randint(2, size=(8,), dtype=torch.float32)
 
         ids = list(range(8))
         sampler.inform_samples(ids, data, forward_outputs, target)
diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py
index 4ba3445e2..d93539715 100644
--- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py
+++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py
@@ -76,27 +76,24 @@ def test_compute_score(sampler_config):
 
 binary_test_data = {
     "LeastConfidence": {
-        "outputs": torch.tensor([[0.8], [0.5], [0.3]]),
-        "expected_scores": np.array([0.8, 0.5, 0.7]),  # confidence just picks the highest probability
+        "outputs": torch.tensor([-0.8, 0.5, 0.3]),
+        "expected_scores": np.array([0.8, 0.5, 0.3]),  # confidence just picks the highest probability
     },
     "Entropy": {
-        "outputs": torch.tensor([[0.8], [0.5], [0.3]]),
+        "outputs": torch.tensor([0.8, 0.5, 0.3]),
         "expected_scores": np.array([-0.5004, -0.6931, -0.6109]),
     },
     "Margin": {
-        "outputs": torch.tensor([[0.8], [0.5], [0.3]]),
+        "outputs": torch.tensor([0.8, 0.5, 0.3]),
         "expected_scores": np.array([0.6, 0.0, 0.4]),  # margin between top two classes
     },
 }
 
 
-@pytest.mark.parametrize("squeeze_dim", [True, False])
-def test_compute_score_binary(sampler_config, squeeze_dim):
+def test_compute_score_binary(sampler_config):
     metric = sampler_config[3]["score_metric"]
     amds = RemoteUncertaintyDownsamplingStrategy(*sampler_config)
     outputs = binary_test_data[metric]["outputs"]
-    if squeeze_dim:
-        outputs = outputs.squeeze()
     expected_scores = binary_test_data[metric]["expected_scores"]
     scores = amds._compute_score(outputs, disable_softmax=True)
     assert np.allclose(scores, expected_scores, atol=1e-4)
diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py
index 4320a4666..b6a1c1b63 100644
--- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py
+++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py
@@ -49,11 +49,14 @@ def inform_samples(
         target: torch.Tensor,
         embedding: torch.Tensor | None = None,
     ) -> None:
+        if forward_output.dim() == 1:
+            # BCEWithLogitsLoss requires that forward_output and target have the same shape
+            forward_output = forward_output.unsqueeze(1)
+            target = target.unsqueeze(1)
+
         last_layer_gradients = self._compute_last_layer_gradient_wrt_loss_sum(
             self.per_sample_loss_fct, forward_output, target
         )
-        if last_layer_gradients.dim() == 1:
-            last_layer_gradients = last_layer_gradients.unsqueeze(1)
         # pylint: disable=not-callable
         scores = torch.linalg.vector_norm(last_layer_gradients, dim=1).cpu()
         self.probabilities.append(scores)
diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_loss_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_loss_downsampling.py
index 239d86425..79bcee16c 100644
--- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_loss_downsampling.py
+++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_loss_downsampling.py
@@ -52,8 +52,6 @@ def inform_samples(
         embedding: torch.Tensor | None = None,
     ) -> None:
         scores = self.get_scores(forward_output, target)
-        if scores.dim() == 2:
-            scores = scores.squeeze(1)
         self.probabilities.append(scores)
         self.number_of_points_seen += forward_output.shape[0]
         self.index_sampleid_map += sample_ids
diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py
index 7daf340ba..b92c64b97 100644
--- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py
+++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py
@@ -72,30 +72,37 @@ def _compute_score(self, forward_output: torch.Tensor, disable_softmax: bool = False) -> np.ndarray:
         if forward_output.dim() == 1:
             forward_output = forward_output.unsqueeze(1)
         feature_size = forward_output.size(1)
-        if feature_size == 1:
-            forward_output = torch.cat((1 - forward_output, forward_output), dim=1)
+
         if self.score_metric == "LeastConfidence":
-            scores = forward_output.max(dim=1).values.cpu().numpy()
-        elif self.score_metric == "Entropy":
-            preds = (
-                torch.nn.functional.softmax(forward_output, dim=1).cpu().numpy()
-                if not disable_softmax
-                else forward_output.cpu().numpy()
-            )
-            scores = (np.log(preds + 1e-6) * preds).sum(axis=1)
-        elif self.score_metric == "Margin":
-            preds = torch.nn.functional.softmax(forward_output, dim=1) if not disable_softmax else forward_output
-            preds_argmax = torch.argmax(preds, dim=1)  # gets top class
-            max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax].clone()  # gets scores of top class
-
-            # remove highest class from softmax output
-            preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax] = -1.0
-
-            preds_sub_argmax = torch.argmax(preds, dim=1)  # gets new top class (=> 2nd top class)
-            second_max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_sub_argmax]
-            scores = (max_preds - second_max_preds).cpu().numpy()
+            if feature_size == 1:
+                # for binary classification comparing how far away the element is from 0.5 after sigmoid layer
+                # is the same as comparing the absolute value of the element before sigmoid layer
+                scores = torch.abs(forward_output).squeeze(1).cpu().numpy()
+            else:
+                scores = forward_output.max(dim=1).values.cpu().numpy()
         else:
-            raise AssertionError("The required metric does not exist")
+            if feature_size == 1:
+                # for binary classification the softmax layer is reduced to sigmoid
+                preds = torch.sigmoid(forward_output) if not disable_softmax else forward_output
+                # we need to convert it to a 2D tensor with probabilities for both classes
+                preds = torch.cat((1 - preds, preds), dim=1)
+            else:
+                preds = torch.nn.functional.softmax(forward_output, dim=1) if not disable_softmax else forward_output
+
+            if self.score_metric == "Entropy":
+                scores = (np.log(preds + 1e-6) * preds).sum(axis=1)
+            elif self.score_metric == "Margin":
+                preds_argmax = torch.argmax(preds, dim=1)  # gets top class
+                max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax].clone()  # gets scores of top class
+
+                # remove highest class from softmax output
+                preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax] = -1.0
+
+                preds_sub_argmax = torch.argmax(preds, dim=1)  # gets new top class (=> 2nd top class)
+                second_max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_sub_argmax]
+                scores = (max_preds - second_max_preds).cpu().numpy()
+            else:
+                raise AssertionError("The required metric does not exist")
 
         return scores
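
The unsqueeze added in remote_gradnorm_downsampling.py follows from the comment in the patch: torch.nn.BCEWithLogitsLoss requires logits and targets of identical shape, so 1-D binary outputs are lifted to (N, 1) before the per-sample losses are computed. A minimal standalone sketch of the shapes involved (not part of the patch):

import torch

loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")
forward_output = torch.randn(8)                            # 1-D binary logits, shape (8,)
target = torch.randint(2, size=(8,), dtype=torch.float32)  # 1-D targets, shape (8,)
# unsqueeze both to (8, 1) so the loss is computed elementwise per sample without a shape mismatch
per_sample_loss = loss_fct(forward_output.unsqueeze(1), target.unsqueeze(1))
assert per_sample_loss.shape == (8, 1)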
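The binary LeastConfidence branch relies on sigmoid being monotonic and symmetric around zero, so ranking single-logit outputs by their distance from 0.5 after the sigmoid gives the same order as ranking them by absolute value before it. A minimal sketch of that equivalence (separate from the patch, using the test's logits):

import torch

logits = torch.tensor([-0.8, 0.5, 0.3])
# distance of the predicted probability from the 0.5 decision boundary
distance_from_half = torch.abs(torch.sigmoid(logits) - 0.5)
# the ranking matches the ranking of the raw |logit| values
assert torch.equal(torch.argsort(distance_from_half), torch.argsort(torch.abs(logits)))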
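For Entropy and Margin, the binary branch expands a single probability p into the two-class distribution (1 - p, p); the expected_scores in binary_test_data follow from that expansion. A small sketch reproducing those numbers (separate from the patch; the max-minus-min margin used here equals the top-two margin only because there are exactly two classes):

import numpy as np
import torch

p = torch.tensor([[0.8], [0.5], [0.3]])
preds = torch.cat((1 - p, p), dim=1).numpy()  # two-class probabilities per sample
entropy_scores = (np.log(preds + 1e-6) * preds).sum(axis=1)
margin_scores = preds.max(axis=1) - preds.min(axis=1)
assert np.allclose(entropy_scores, [-0.5004, -0.6931, -0.6109], atol=1e-4)
assert np.allclose(margin_scores, [0.6, 0.0, 0.4], atol=1e-4)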