Skip to content

Commit

Permalink
Fix or disabled pyre issues that show up after Pyre Update (#1438)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1438

Fix or disabled pyre issues that show up after update to Pyre

Reviewed By: jjuncho

Differential Revision: D65826206

fbshipit-source-id: 437d05308bb3439e8baa69172fa9408a1acbff60
  • Loading branch information
cyrjano authored and facebook-github-bot committed Nov 12, 2024
1 parent 6540e74 commit e3a3574
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 2 deletions.
2 changes: 2 additions & 0 deletions captum/attr/_core/dataloader_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def _perturb_inputs(
else:
baseline = baselines[attr_inp_count]

# pyre-fixme[58]: `*` is not supported for operand types `object` and
# `Tensor`.
perturbed_inp = inp * pert_mask + baseline * (1 - pert_mask)
perturbed_inputs.append(perturbed_inp)

Expand Down
1 change: 1 addition & 0 deletions captum/attr/_core/feature_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ def _construct_ablated_input(
dim=0,
).long()
current_mask = current_mask.to(expanded_input.device)
assert baseline is not None, "baseline must be provided"
ablated_tensor = (
expanded_input * (1 - current_mask).to(expanded_input.dtype)
) + (baseline * current_mask.to(expanded_input.dtype))
Expand Down
1 change: 1 addition & 0 deletions captum/attr/_core/occlusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def _construct_ablated_input(
],
dim=0,
).long()
assert baseline is not None, "baseline should not be None"
ablated_tensor = (
expanded_input
* (
Expand Down
2 changes: 2 additions & 0 deletions captum/metrics/_core/infidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,8 @@ def _next_infidelity_tensors(
inputs_fwd = torch.repeat_interleave(
inputs_fwd, current_n_perturb_samples, dim=0
)
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
# `Union[Future[Tensor], Tensor]`.
perturbed_fwd_diffs = inputs_fwd - inputs_perturbed_fwd
attributions_expanded = tuple(
torch.repeat_interleave(attribution, current_n_perturb_samples, dim=0)
Expand Down
4 changes: 3 additions & 1 deletion captum/module/gaussian_stochastic_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ def _get_gate_active_probs(self) -> Tensor:
probs (Tensor): probabilities tensor of the gates are active
in shape(n_gates)
"""
x = self.mu / self.std
std = self.std
assert std is not None, "std should not be None"
x = self.mu / std
return 0.5 * (1 + torch.erf(x / math.sqrt(2)))

@classmethod
Expand Down
4 changes: 3 additions & 1 deletion tests/attr/test_input_x_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def _input_x_gradient_classification_assert(self, nt_type: str = "vanilla") -> N
attributions = input_x_grad.attribute(input, target)
output = model(input)[:, target]
output.backward()
expected = input.grad * input
input_grad = input.grad
assert input_grad is not None
expected = input_grad * input
assertTensorAlmostEqual(self, attributions, expected, 0.00001, "max")
else:
nt = NoiseTunnel(input_x_grad)
Expand Down

0 comments on commit e3a3574

Please sign in to comment.