From 8e23bb5ded66c06aeb64416b311b92cee688a697 Mon Sep 17 00:00:00 2001 From: Christy Sauper Date: Mon, 28 Oct 2024 12:46:13 -0700 Subject: [PATCH] Pyre fixes for common.py [3/n] (#1424) Summary: Rewriting from D64259572 after BE week Reviewed By: cyrjano Differential Revision: D65011997 --- captum/_utils/common.py | 37 +++++++++++++------------------------ 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 2470ae0c1..b3f6db1e8 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -278,22 +278,13 @@ def _format_additional_forward_args(additional_forward_args: None) -> None: ... @overload def _format_additional_forward_args( - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. - additional_forward_args: Union[Tensor, Tuple] - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. -) -> Tuple: ... + additional_forward_args: Union[object, Tuple[object, ...]] +) -> Tuple[object, ...]: ... -@overload -def _format_additional_forward_args( # type: ignore - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any, - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. -) -> Union[None, Tuple]: ... - - -# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. -def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]: +def _format_additional_forward_args( + additional_forward_args: Union[object, Tuple[object, ...], None] +) -> Union[None, Tuple[object, ...]]: if additional_forward_args is not None and not isinstance( additional_forward_args, tuple ): @@ -853,8 +844,7 @@ def _register_backward_hook( module: Module, # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. hook: Callable, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - attr_obj: Any, + attr_obj: Union[object, None], ) -> List[torch.utils.hooks.RemovableHandle]: grad_out: Dict[device, Tensor] = {} @@ -864,10 +854,9 @@ def forward_hook( out: Union[Tensor, Tuple[Tensor, ...]], ) -> None: nonlocal grad_out - grad_out = {} - # pyre-fixme[53]: Captured variable `grad_out` is not annotated. def output_tensor_hook(output_grad: Tensor) -> None: + nonlocal grad_out grad_out[output_grad.device] = output_grad if isinstance(out, tuple): @@ -878,12 +867,12 @@ def output_tensor_hook(output_grad: Tensor) -> None: else: out.register_hook(output_tensor_hook) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def pre_hook(module, inp): - # pyre-fixme[53]: Captured variable `module` is not annotated. - # pyre-fixme[3]: Return type must be annotated. - def input_tensor_hook(input_grad: Tensor): + def pre_hook(module: Module, inp: Union[Tensor, Tuple[Tensor, ...]]) -> Tensor: + def input_tensor_hook( + input_grad: Tensor, + ) -> Union[None, Tensor, Tuple[Tensor, ...]]: + nonlocal grad_out + if len(grad_out) == 0: return hook_out = hook(module, input_grad, grad_out[input_grad.device])