diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 2470ae0c1..12784bfaf 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -12,6 +12,7 @@ Dict, List, Literal, + Optional, overload, Sequence, Tuple, @@ -272,28 +273,9 @@ def _format_float_or_tensor_into_tuples( return inputs -@overload -def _format_additional_forward_args(additional_forward_args: None) -> None: ... - - -@overload def _format_additional_forward_args( - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. - additional_forward_args: Union[Tensor, Tuple] - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. -) -> Tuple: ... - - -@overload -def _format_additional_forward_args( # type: ignore - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any, - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. -) -> Union[None, Tuple]: ... - - -# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. -def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]: + additional_forward_args: Optional[object], +) -> Union[None, Tuple[object, ...]]: if additional_forward_args is not None and not isinstance( additional_forward_args, tuple ): @@ -853,8 +835,7 @@ def _register_backward_hook( module: Module, # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. hook: Callable, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - attr_obj: Any, + attr_obj: Union[object, None], ) -> List[torch.utils.hooks.RemovableHandle]: grad_out: Dict[device, Tensor] = {} @@ -864,10 +845,9 @@ def forward_hook( out: Union[Tensor, Tuple[Tensor, ...]], ) -> None: nonlocal grad_out - grad_out = {} - # pyre-fixme[53]: Captured variable `grad_out` is not annotated. def output_tensor_hook(output_grad: Tensor) -> None: + nonlocal grad_out grad_out[output_grad.device] = output_grad if isinstance(out, tuple): @@ -878,18 +858,19 @@ def output_tensor_hook(output_grad: Tensor) -> None: else: out.register_hook(output_tensor_hook) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def pre_hook(module, inp): - # pyre-fixme[53]: Captured variable `module` is not annotated. - # pyre-fixme[3]: Return type must be annotated. - def input_tensor_hook(input_grad: Tensor): + def pre_hook(module: Module, inp: Union[Tensor, Tuple[Tensor, ...]]) -> Tensor: + def input_tensor_hook( + input_grad: Tensor, + ) -> Union[None, Tensor, Tuple[Tensor, ...]]: + nonlocal grad_out + if len(grad_out) == 0: - return + return None hook_out = hook(module, input_grad, grad_out[input_grad.device]) if hook_out is not None: return hook_out[0] if isinstance(hook_out, tuple) else hook_out + return None if isinstance(inp, tuple): assert (