Skip to content

Commit

Permalink
refix
Browse files Browse the repository at this point in the history
  • Loading branch information
XianzheMa committed Sep 7, 2024
1 parent b01f3bc commit d08beda
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def test_sample_shape_ce(dummy_system_config: ModynConfig):
assert set(downsampled_indexes) <= set(range(8))


@pytest.mark.parametrize("squeeze_dim", [True, False])
def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim):
def test_sample_shape_binary(dummy_system_config: ModynConfig):
model = torch.nn.Linear(10, 1)
downsampling_ratio = 50
per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")
Expand All @@ -53,11 +52,8 @@ def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim):
)
with torch.inference_mode(mode=(not sampler.requires_grad)):
data = torch.randn(8, 10)
forward_outputs = model(data)
target = torch.randint(2, size=(8,), dtype=torch.float32).unsqueeze(1)
if squeeze_dim:
target = target.squeeze(1)
forward_outputs = forward_outputs.squeeze(1)
forward_outputs = model(data).squeeze(1)
target = torch.randint(2, size=(8,), dtype=torch.float32)
ids = list(range(8))

sampler.inform_samples(ids, data, forward_outputs, target)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def test_sample_shape(dummy_system_config: ModynConfig):
assert len(indexes) == 4


@pytest.mark.parametrize("squeeze_dim", [True, False])
def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim):
def test_sample_shape_binary(dummy_system_config: ModynConfig):
model = torch.nn.Linear(10, 1)
downsampling_ratio = 50
per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")
Expand All @@ -48,11 +47,8 @@ def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim):
)
with torch.inference_mode(mode=(not sampler.requires_grad)):
data = torch.randn(8, 10)
forward_outputs = model(data)
target = torch.randint(2, size=(8,), dtype=torch.float32).unsqueeze(1)
if squeeze_dim:
target = target.squeeze(1)
forward_outputs = forward_outputs.squeeze(1)
forward_outputs = model(data).squeeze(1)
target = torch.randint(2, size=(8,), dtype=torch.float32)
ids = list(range(8))

sampler.inform_samples(ids, data, forward_outputs, target)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,27 +76,24 @@ def test_compute_score(sampler_config):

binary_test_data = {
"LeastConfidence": {
"outputs": torch.tensor([[0.8], [0.5], [0.3]]),
"expected_scores": np.array([0.8, 0.5, 0.7]), # confidence just picks the highest probability
"outputs": torch.tensor([-0.8, 0.5, 0.3]),
"expected_scores": np.array([0.8, 0.5, 0.3]), # with a single logit per sample, the score is the absolute value of that logit
},
"Entropy": {
"outputs": torch.tensor([[0.8], [0.5], [0.3]]),
"outputs": torch.tensor([0.8, 0.5, 0.3]),
"expected_scores": np.array([-0.5004, -0.6931, -0.6109]),
},
"Margin": {
"outputs": torch.tensor([[0.8], [0.5], [0.3]]),
"outputs": torch.tensor([0.8, 0.5, 0.3]),
"expected_scores": np.array([0.6, 0.0, 0.4]), # margin between top two classes
},
}


@pytest.mark.parametrize("squeeze_dim", [True, False])
def test_compute_score_binary(sampler_config, squeeze_dim):
def test_compute_score_binary(sampler_config):
metric = sampler_config[3]["score_metric"]
amds = RemoteUncertaintyDownsamplingStrategy(*sampler_config)
outputs = binary_test_data[metric]["outputs"]
if squeeze_dim:
outputs = outputs.squeeze()
expected_scores = binary_test_data[metric]["expected_scores"]
scores = amds._compute_score(outputs, disable_softmax=True)
assert np.allclose(scores, expected_scores, atol=1e-4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,14 @@ def inform_samples(
target: torch.Tensor,
embedding: torch.Tensor | None = None,
) -> None:
if forward_output.dim() == 1:
# BCEWithLogitsLoss requires that forward_output and target have the same shape
forward_output = forward_output.unsqueeze(1)
target = target.unsqueeze(1)

last_layer_gradients = self._compute_last_layer_gradient_wrt_loss_sum(
self.per_sample_loss_fct, forward_output, target
)
if last_layer_gradients.dim() == 1:
last_layer_gradients = last_layer_gradients.unsqueeze(1)
# pylint: disable=not-callable
scores = torch.linalg.vector_norm(last_layer_gradients, dim=1).cpu()
self.probabilities.append(scores)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def inform_samples(
embedding: torch.Tensor | None = None,
) -> None:
scores = self.get_scores(forward_output, target)
if scores.dim() == 2:
scores = scores.squeeze(1)
self.probabilities.append(scores)
self.number_of_points_seen += forward_output.shape[0]
self.index_sampleid_map += sample_ids
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,30 +72,37 @@ def _compute_score(self, forward_output: torch.Tensor, disable_softmax: bool = F
if forward_output.dim() == 1:
forward_output = forward_output.unsqueeze(1)
feature_size = forward_output.size(1)
if feature_size == 1:
forward_output = torch.cat((1 - forward_output, forward_output), dim=1)

if self.score_metric == "LeastConfidence":
scores = forward_output.max(dim=1).values.cpu().numpy()
elif self.score_metric == "Entropy":
preds = (
torch.nn.functional.softmax(forward_output, dim=1).cpu().numpy()
if not disable_softmax
else forward_output.cpu().numpy()
)
scores = (np.log(preds + 1e-6) * preds).sum(axis=1)
elif self.score_metric == "Margin":
preds = torch.nn.functional.softmax(forward_output, dim=1) if not disable_softmax else forward_output
preds_argmax = torch.argmax(preds, dim=1) # gets top class
max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax].clone() # gets scores of top class

# remove highest class from softmax output
preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax] = -1.0

preds_sub_argmax = torch.argmax(preds, dim=1) # gets new top class (=> 2nd top class)
second_max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_sub_argmax]
scores = (max_preds - second_max_preds).cpu().numpy()
if feature_size == 1:
# for binary classification, comparing how far the sigmoid output is from 0.5
# is equivalent to comparing the absolute value of the logit (the pre-sigmoid output)
scores = torch.abs(forward_output).squeeze(1).cpu().numpy()
else:
scores = forward_output.max(dim=1).values.cpu().numpy()
else:
raise AssertionError("The required metric does not exist")
if feature_size == 1:
# for binary classification the softmax layer is reduced to sigmoid
preds = torch.sigmoid(forward_output) if not disable_softmax else forward_output
# we need to convert it to a 2D tensor with probabilities for both classes
preds = torch.cat((1 - preds, preds), dim=1)
else:
preds = torch.nn.functional.softmax(forward_output, dim=1) if not disable_softmax else forward_output

if self.score_metric == "Entropy":
scores = (np.log(preds + 1e-6) * preds).sum(axis=1)
elif self.score_metric == "Margin":
preds_argmax = torch.argmax(preds, dim=1) # gets top class
max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax].clone() # gets scores of top class

# remove highest class from softmax output
preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax] = -1.0

preds_sub_argmax = torch.argmax(preds, dim=1) # gets new top class (=> 2nd top class)
second_max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_sub_argmax]
scores = (max_preds - second_max_preds).cpu().numpy()
else:
raise AssertionError("The required metric does not exist")

Check warning on line 105 in modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py

View check run for this annotation

Codecov / codecov/patch

modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py#L105

Added line #L105 was not covered by tests

return scores

Expand Down

0 comments on commit d08beda

Please sign in to comment.