From e8b6d98a23feee9af444f48807aa6902952bd7c4 Mon Sep 17 00:00:00 2001 From: Christy Sauper Date: Thu, 31 Oct 2024 10:48:10 -0700 Subject: [PATCH] Fix more influence functions [6/n] (#1428) Summary: Pull Request resolved: https://github.com/pytorch/captum/pull/1428 additional pyre fixes in `influence_function.py` Reviewed By: cyrjano Differential Revision: D65256481 fbshipit-source-id: 4fab2466b0e7072fc161e3185e09bf87bd4c58c1 --- captum/influence/_core/influence_function.py | 43 +++++++++----------- captum/influence/_core/tracincp.py | 3 +- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/captum/influence/_core/influence_function.py b/captum/influence/_core/influence_function.py index 6fdaf4a93..1c44b731c 100644 --- a/captum/influence/_core/influence_function.py +++ b/captum/influence/_core/influence_function.py @@ -35,7 +35,7 @@ KMostInfluentialResults, ) from captum.log import log_usage -from torch import Tensor +from torch import device, Tensor from torch.nn import Module from torch.utils.data import DataLoader, Dataset from tqdm import tqdm @@ -69,8 +69,9 @@ def __init__( model: Module, train_dataset: Union[Dataset, DataLoader], checkpoint: str, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - checkpoints_load_func: Callable = _load_flexible_state_dict, + checkpoints_load_func: Callable[ + [Module, str], float + ] = _load_flexible_state_dict, layers: Optional[List[str]] = None, # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. loss_fn: Optional[Union[Module, Callable]] = None, @@ -243,21 +244,19 @@ def __init__( # `loss_fn` is needed to compute the Hessian. # check `loss_fn` - # pyre-fixme[4]: Attribute must be annotated. - self.reduction_type = _check_loss_fn( + self.reduction_type: str = _check_loss_fn( self, loss_fn, "loss_fn", sample_wise_grads_per_batch ) # check `test_loss_fn` if it was provided + self.test_reduction_type: str = "" if not (test_loss_fn is None): - # pyre-fixme[4]: Attribute must be annotated. self.test_reduction_type = _check_loss_fn( self, test_loss_fn, "test_loss_fn", sample_wise_grads_per_batch ) else: self.test_reduction_type = self.reduction_type - # pyre-fixme[4]: Attribute must be annotated. - self.layer_modules = None + self.layer_modules: Optional[List[Module]] = None if not (layers is None): self.layer_modules = _set_active_parameters(model, layers) @@ -703,8 +702,9 @@ def __init__( model: Module, train_dataset: Union[Dataset, DataLoader], checkpoint: str, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - checkpoints_load_func: Callable = _load_flexible_state_dict, + checkpoints_load_func: Callable[ + [Module, str], float + ] = _load_flexible_state_dict, layers: Optional[List[str]] = None, # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. loss_fn: Optional[Union[Module, Callable]] = None, @@ -869,11 +869,9 @@ def __init__( # infer the device the model is on. all parameters are assumed to be on the # same device - # pyre-fixme[4]: Attribute must be annotated. - self.model_device = next(model.parameters()).device + self.model_device: device = next(model.parameters()).device - # pyre-fixme[4]: Attribute must be annotated. 
- self.R = self._retrieve_projections_naive_influence_function( + self.R: Tensor = self._retrieve_projections_naive_influence_function( self.hessian_dataloader, projection_on_cpu, show_progress, @@ -1035,11 +1033,11 @@ def compute_intermediate_quantities( inputs_dataset = _progress_bar_constructor( self, inputs_dataset, "inputs_dataset", "intermediate quantities" ) + # infer model / data device through model - if return_on_cpu is None or (not return_on_cpu): - return_device = self.model_device - else: - return_device = torch.device("cpu") + return_device: device = ( + torch.device("cpu") if return_on_cpu else self.model_device + ) # as described in the description for `NaiveInfluenceFunction`, the embedding # for an example `x` is :math`\nabla_\theta L(x)' R`. @@ -1051,15 +1049,12 @@ def compute_intermediate_quantities( # choose the correct loss function and reduction type based on `test` loss_fn = self.test_loss_fn if test else self.loss_fn - reduction_type = self.test_reduction_type if test else self.reduction_type + reduction_type: str = self.test_reduction_type if test else self.reduction_type # define a helper function that returns the embeddings for a batch # pyre-fixme[53]: Captured variable `loss_fn` is not annotated. - # pyre-fixme[53]: Captured variable `reduction_type` is not annotated. - # pyre-fixme[53]: Captured variable `return_device` is not annotated. - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def get_batch_embeddings(batch): + def get_batch_embeddings(batch: Tuple[Tensor, ...]) -> Tensor: + nonlocal loss_fn, reduction_type, return_device # if `self.R` is on cpu, and `self.model_device` was not cpu, this implies # `self.R` was too large to fit in gpu memory, and we should do the matrix # multiplication of the batch jacobians with `self.R` separately for each diff --git a/captum/influence/_core/tracincp.py b/captum/influence/_core/tracincp.py index 5c9d9cfaa..f36d38323 100644 --- a/captum/influence/_core/tracincp.py +++ b/captum/influence/_core/tracincp.py @@ -1222,8 +1222,7 @@ def _self_influence_by_checkpoints( stacklevel=1, ) - # pyre-fixme[2]: Parameter must be annotated. - def calculate_via_vector_norm(layer_jacobian) -> Tensor: + def calculate_via_vector_norm(layer_jacobian: Tensor) -> Tensor: # Helper to efficiently calculate vector norm if pytorch version permits. return ( torch.linalg.vector_norm(
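
Note: the recurring pattern in this patch is to replace `# pyre-fixme` suppressions with explicit annotations: bare `Callable` parameters become fully parameterized (`Callable[[Module, str], float]`), and instance attributes are annotated at their first assignment instead of carrying a `# pyre-fixme[4]` comment. The sketch below illustrates both cases outside of captum. It is only an illustration of the annotation pattern, not captum code: `ExampleInfluence` and `load_checkpoint` are hypothetical names, with `load_checkpoint` standing in for `_load_flexible_state_dict`, which the patch types as `(Module, str) -> float`.

from typing import Callable, List, Optional

import torch
from torch import device
from torch.nn import Module


def load_checkpoint(model: Module, path: str) -> float:
    # Hypothetical loader with the same (Module, str) -> float signature
    # that the patch gives `checkpoints_load_func`. Assumes the file at
    # `path` holds a plain state dict.
    model.load_state_dict(torch.load(path))
    return 1.0


class ExampleInfluence:
    def __init__(
        self,
        model: Module,
        checkpoint: str,
        # Fully parameterized Callable: pyre can check the default and any
        # user-supplied loader, so no `# pyre-fixme[24]` suppression is needed.
        checkpoints_load_func: Callable[[Module, str], float] = load_checkpoint,
        layers: Optional[List[str]] = None,
    ) -> None:
        # Annotating attributes at their first assignment replaces the
        # `# pyre-fixme[4]: Attribute must be annotated.` comments.
        self.model_device: device = next(model.parameters()).device
        self.layer_modules: Optional[List[Module]] = None
        checkpoints_load_func(model, checkpoint)

Typing the callable on the parameter itself, rather than suppressing the error, means pyre checks the default loader and any caller-provided one against a single signature, which is why the same annotation appears in both `__init__` signatures touched above.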