Fix more influence functions [6/n] (#1428)
Summary:
Pull Request resolved: #1428

additional pyre fixes in `influence_function.py`

Reviewed By: cyrjano

Differential Revision: D65256481

fbshipit-source-id: 4fab2466b0e7072fc161e3185e09bf87bd4c58c1
csauper authored and facebook-github-bot committed Oct 31, 2024
1 parent d28efcd commit e8b6d98
Showing 2 changed files with 20 additions and 26 deletions.
captum/influence/_core/influence_function.py (43 changes: 19 additions & 24 deletions)
@@ -35,7 +35,7 @@
     KMostInfluentialResults,
 )
 from captum.log import log_usage
-from torch import Tensor
+from torch import device, Tensor
 from torch.nn import Module
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
@@ -69,8 +69,9 @@ def __init__(
         model: Module,
         train_dataset: Union[Dataset, DataLoader],
         checkpoint: str,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        checkpoints_load_func: Callable = _load_flexible_state_dict,
+        checkpoints_load_func: Callable[
+            [Module, str], float
+        ] = _load_flexible_state_dict,
         layers: Optional[List[str]] = None,
         # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
         loss_fn: Optional[Union[Module, Callable]] = None,
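For context, a minimal sketch (not part of this diff) of the fix this hunk makes: spelling out the `Callable`'s parameter list and return type is what allows the `pyre-fixme[24]` suppression to be removed. `toy_load_checkpoint` below is a hypothetical stand-in for `_load_flexible_state_dict`.

from typing import Callable

import torch
from torch.nn import Module


def toy_load_checkpoint(model: Module, path: str) -> float:
    # Hypothetical loader with the same (Module, str) -> float shape: load the
    # saved state dict into the model and return a float (e.g. a learning rate
    # stored alongside the weights).
    model.load_state_dict(torch.load(path))
    return 1.0


# A bare `Callable` is what pyre-fixme[24] flags; a fully parameterized
# annotation makes the signature explicit.
checkpoints_load_func: Callable[[Module, str], float] = toy_load_checkpoint

The `(Module, str) -> float` shape matches how the loader is used: it takes the model and a checkpoint path, updates the model's parameters, and returns a float (in Captum's case, the learning rate read from the checkpoint, if any).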
@@ -243,21 +244,19 @@ def __init__(
         # `loss_fn` is needed to compute the Hessian.

         # check `loss_fn`
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.reduction_type = _check_loss_fn(
+        self.reduction_type: str = _check_loss_fn(
             self, loss_fn, "loss_fn", sample_wise_grads_per_batch
         )
         # check `test_loss_fn` if it was provided
+        self.test_reduction_type: str = ""
         if not (test_loss_fn is None):
-            # pyre-fixme[4]: Attribute must be annotated.
             self.test_reduction_type = _check_loss_fn(
                 self, test_loss_fn, "test_loss_fn", sample_wise_grads_per_batch
             )
         else:
             self.test_reduction_type = self.reduction_type

-        # pyre-fixme[4]: Attribute must be annotated.
-        self.layer_modules = None
+        self.layer_modules: Optional[List[Module]] = None
         if not (layers is None):
             self.layer_modules = _set_active_parameters(model, layers)
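A minimal sketch (not Captum code, all names hypothetical) of the attribute-annotation pattern this hunk applies: annotate each attribute at its first assignment, and give `test_reduction_type` a default of its declared type before the branch, so the `pyre-fixme[4]` suppressions can be dropped.

from typing import List, Optional

from torch.nn import Module


class Example:
    def __init__(self, layers: Optional[List[str]] = None) -> None:
        # Annotating the first assignment removes the need for a
        # "pyre-fixme[4]: Attribute must be annotated" suppression.
        self.reduction_type: str = "sum"

        # Give the attribute its declared type before branching; both branches
        # then only re-assign another str.
        self.test_reduction_type: str = ""
        if layers is not None:
            self.test_reduction_type = "mean"
        else:
            self.test_reduction_type = self.reduction_type

        # Attributes that may stay unset get an explicit Optional annotation.
        self.layer_modules: Optional[List[Module]] = None
        if layers is not None:
            self.layer_modules = []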

@@ -703,8 +702,9 @@ def __init__(
         model: Module,
         train_dataset: Union[Dataset, DataLoader],
         checkpoint: str,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        checkpoints_load_func: Callable = _load_flexible_state_dict,
+        checkpoints_load_func: Callable[
+            [Module, str], float
+        ] = _load_flexible_state_dict,
         layers: Optional[List[str]] = None,
         # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
         loss_fn: Optional[Union[Module, Callable]] = None,
@@ -869,11 +869,9 @@ def __init__(

         # infer the device the model is on. all parameters are assumed to be on the
         # same device
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.model_device = next(model.parameters()).device
+        self.model_device: device = next(model.parameters()).device

-        # pyre-fixme[4]: Attribute must be annotated.
-        self.R = self._retrieve_projections_naive_influence_function(
+        self.R: Tensor = self._retrieve_projections_naive_influence_function(
             self.hessian_dataloader,
             projection_on_cpu,
             show_progress,
@@ -1035,11 +1033,11 @@ def compute_intermediate_quantities(
             inputs_dataset = _progress_bar_constructor(
                 self, inputs_dataset, "inputs_dataset", "intermediate quantities"
             )
+
         # infer model / data device through model
-        if return_on_cpu is None or (not return_on_cpu):
-            return_device = self.model_device
-        else:
-            return_device = torch.device("cpu")
+        return_device: device = (
+            torch.device("cpu") if return_on_cpu else self.model_device
+        )

         # as described in the description for `NaiveInfluenceFunction`, the embedding
         # for an example `x` is :math`\nabla_\theta L(x)' R`.
@@ -1051,15 +1049,12 @@

         # choose the correct loss function and reduction type based on `test`
         loss_fn = self.test_loss_fn if test else self.loss_fn
-        reduction_type = self.test_reduction_type if test else self.reduction_type
+        reduction_type: str = self.test_reduction_type if test else self.reduction_type

         # define a helper function that returns the embeddings for a batch
-        # pyre-fixme[53]: Captured variable `loss_fn` is not annotated.
-        # pyre-fixme[53]: Captured variable `reduction_type` is not annotated.
-        # pyre-fixme[53]: Captured variable `return_device` is not annotated.
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_batch_embeddings(batch):
+        def get_batch_embeddings(batch: Tuple[Tensor, ...]) -> Tensor:
+            nonlocal loss_fn, reduction_type, return_device
             # if `self.R` is on cpu, and `self.model_device` was not cpu, this implies
             # `self.R` was too large to fit in gpu memory, and we should do the matrix
             # multiplication of the batch jacobians with `self.R` separately for each
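A minimal sketch (not Captum code) of the closure pattern used in the hunk above: annotate the nested helper's parameter and return type, and declare the captured names `nonlocal` so their types come from the annotated variables in the enclosing scope, which is how the removed `pyre-fixme[53]`, `[3]`, and `[2]` suppressions were addressed. `compute_embeddings` and its arguments are hypothetical.

from typing import List, Tuple

import torch
from torch import Tensor, device
from torch.nn import Module


def compute_embeddings(
    model: Module,
    batches: List[Tuple[Tensor, ...]],
    return_on_cpu: bool = True,
) -> Tensor:
    # Mirror of the ternary above: pick the device results are returned on.
    model_device: device = next(model.parameters()).device
    return_device: device = torch.device("cpu") if return_on_cpu else model_device

    # Fully annotated nested helper; `nonlocal` ties the captured name to the
    # annotated variable in the enclosing scope.
    def get_batch_embeddings(batch: Tuple[Tensor, ...]) -> Tensor:
        nonlocal return_device
        features = batch[0]
        return model(features).detach().to(device=return_device)

    return torch.cat([get_batch_embeddings(batch) for batch in batches])

For example, `compute_embeddings(torch.nn.Linear(3, 2), [(torch.randn(4, 3),)])` returns a 4 x 2 tensor on CPU.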
captum/influence/_core/tracincp.py (3 changes: 1 addition & 2 deletions)
@@ -1222,8 +1222,7 @@ def _self_influence_by_checkpoints(
                 stacklevel=1,
             )

-        # pyre-fixme[2]: Parameter must be annotated.
-        def calculate_via_vector_norm(layer_jacobian) -> Tensor:
+        def calculate_via_vector_norm(layer_jacobian: Tensor) -> Tensor:
             # Helper to efficiently calculate vector norm if pytorch version permits.
             return (
                 torch.linalg.vector_norm(
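The hunk is truncated above, so as a hypothetical sketch only: one way a helper like `calculate_via_vector_norm` can use `torch.linalg.vector_norm` to reduce a layer jacobian to one norm per example, flattening every non-batch dimension.

import torch
from torch import Tensor


def calculate_via_vector_norm(layer_jacobian: Tensor) -> Tensor:
    # Flatten all dimensions except the leading batch dimension, then take
    # the 2-norm along the flattened dimension, giving one value per example.
    return torch.linalg.vector_norm(layer_jacobian.flatten(start_dim=1), dim=1)

For instance, `calculate_via_vector_norm(torch.ones(5, 3, 2))` returns a length-5 tensor whose entries are all sqrt(6).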
