diff --git a/captum/_utils/common.py b/captum/_utils/common.py
index 0a9a42770..2470ae0c1 100644
--- a/captum/_utils/common.py
+++ b/captum/_utils/common.py
@@ -90,8 +90,8 @@ def _is_tuple(inputs: Tensor) -> Literal[False]: ...
 
 @typing.overload
 def _is_tuple(
-    inputs: TensorOrTupleOfTensorsGeneric,
-) -> bool: ...  # type: ignore
+    inputs: TensorOrTupleOfTensorsGeneric,  # type: ignore
+) -> bool: ...
 
 
 def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py
index 538135003..da5ffd224 100644
--- a/captum/_utils/typing.py
+++ b/captum/_utils/typing.py
@@ -2,7 +2,18 @@
 
 # pyre-strict
 
-from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union
+from collections import UserDict
+from typing import (
+    List,
+    Literal,
+    Optional,
+    overload,
+    Protocol,
+    Tuple,
+    TYPE_CHECKING,
+    TypeVar,
+    Union,
+)
 
 from torch import Tensor
 from torch.nn import Module
@@ -30,6 +41,13 @@
 ]
 
 
+# Necessary for Python >=3.7 and <3.9!
+if TYPE_CHECKING:
+    BatchEncodingType = UserDict[Union[int, str], object]
+else:
+    BatchEncodingType = UserDict
+
+
 class TokenizerLike(Protocol):
     """A protocol for tokenizer-like objects that can be used with Captum
     LLM attribution methods."""
@@ -62,3 +80,9 @@ def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ...
     def convert_tokens_to_ids(
         self, tokens: Union[List[str], str]
     ) -> Union[List[int], int]: ...
+
+    def __call__(
+        self,
+        text: Optional[Union[str, List[str], List[List[str]]]] = None,
+        return_offsets_mapping: bool = False,
+    ) -> BatchEncodingType: ...
diff --git a/tests/attr/test_interpretable_input.py b/tests/attr/test_interpretable_input.py
index 085813b09..10e4408eb 100644
--- a/tests/attr/test_interpretable_input.py
+++ b/tests/attr/test_interpretable_input.py
@@ -5,6 +5,7 @@
 from typing import List, Literal, Optional, overload, Union
 
 import torch
+from captum._utils.typing import BatchEncodingType
 from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
 from parameterized import parameterized
 from tests.helpers import BaseTest
@@ -68,6 +69,13 @@ def convert_tokens_to_ids(
     def decode(self, token_ids: Tensor) -> str:
         raise NotImplementedError
 
+    def __call__(
+        self,
+        text: Optional[Union[str, List[str], List[List[str]]]] = None,
+        return_offsets_mapping: bool = False,
+    ) -> BatchEncodingType:
+        raise NotImplementedError
+
 
 class TestTextTemplateInput(BaseTest):
     @parameterized.expand(
diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py
index d22bef384..b5b4cedfc 100644
--- a/tests/attr/test_llm_attr.py
+++ b/tests/attr/test_llm_attr.py
@@ -19,6 +19,7 @@
 
 import torch
 from captum._utils.models.linear_model import SkLearnLasso
+from captum._utils.typing import BatchEncodingType
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
 from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
@@ -96,6 +97,13 @@ def decode(self, token_ids: Tensor) -> str:
         # pyre-fixme[7]: Expected `str` but got `Union[List[str], str]`.
         return tokens if isinstance(tokens, str) else " ".join(tokens)
 
+    def __call__(
+        self,
+        text: Optional[Union[str, List[str], List[List[str]]]] = None,
+        return_offsets_mapping: bool = False,
+    ) -> BatchEncodingType:
+        raise NotImplementedError
+
 
 class Result(NamedTuple):
     logits: Tensor