Skip to content

Commit

Permalink
Add __call__ to TokenizerLike (#1418)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1418

Add a `__call__` method to the `TokenizerLike` protocol for compatibility with Hugging Face `transformers` tokenizers.

Reviewed By: vivekmig

Differential Revision: D64998805

fbshipit-source-id: ac9fe813267f21ccd7bc9207fa1e951d66c33ac0
  • Loading branch information
craymichael authored and facebook-github-bot committed Oct 28, 2024
1 parent 8437daf commit 2c1c281
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 3 deletions.
4 changes: 2 additions & 2 deletions captum/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def _is_tuple(inputs: Tensor) -> Literal[False]: ...

@typing.overload
def _is_tuple(
inputs: TensorOrTupleOfTensorsGeneric,
) -> bool: ... # type: ignore
inputs: TensorOrTupleOfTensorsGeneric, # type: ignore
) -> bool: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
Expand Down
26 changes: 25 additions & 1 deletion captum/_utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@

# pyre-strict

from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union
from collections import UserDict
from typing import (
List,
Literal,
Optional,
overload,
Protocol,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)

from torch import Tensor
from torch.nn import Module
Expand Down Expand Up @@ -30,6 +41,13 @@
]


# `UserDict[...]` can only be subscripted at runtime on Python >= 3.9, so the
# parameterized alias is defined for static type checkers only; at runtime the
# bare `UserDict` class is used instead.
# Necessary for Python >=3.7 and <3.9!
if TYPE_CHECKING:
BatchEncodingType = UserDict[Union[int, str], object]
else:
BatchEncodingType = UserDict


class TokenizerLike(Protocol):
"""A protocol for tokenizer-like objects that can be used with Captum
LLM attribution methods."""
Expand Down Expand Up @@ -62,3 +80,9 @@ def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ...
def convert_tokens_to_ids(
self, tokens: Union[List[str], str]
) -> Union[List[int], int]: ...

# Call-style entry point of the protocol, added for `transformers`
# compatibility (per the commit summary). `text` may be a single string, a
# list of strings, or a list of lists of strings, and defaults to None. The
# declared return type is `BatchEncodingType`, a `UserDict`-based alias.
# NOTE(review): `return_offsets_mapping` presumably mirrors the Hugging Face
# tokenizer flag of the same name -- confirm against callers.
def __call__(
self,
text: Optional[Union[str, List[str], List[List[str]]]] = None,
return_offsets_mapping: bool = False,
) -> BatchEncodingType: ...
8 changes: 8 additions & 0 deletions tests/attr/test_interpretable_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List, Literal, Optional, overload, Union

import torch
from captum._utils.typing import BatchEncodingType
from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
from parameterized import parameterized
from tests.helpers import BaseTest
Expand Down Expand Up @@ -68,6 +69,13 @@ def convert_tokens_to_ids(
def decode(self, token_ids: Tensor) -> str:
raise NotImplementedError

# Stub with the same signature as `TokenizerLike.__call__` in
# captum/_utils/typing.py. It is intentionally unimplemented: any call raises
# NotImplementedError. (The enclosing test class is outside this diff hunk.)
def __call__(
self,
text: Optional[Union[str, List[str], List[List[str]]]] = None,
return_offsets_mapping: bool = False,
) -> BatchEncodingType:
raise NotImplementedError


class TestTextTemplateInput(BaseTest):
@parameterized.expand(
Expand Down
8 changes: 8 additions & 0 deletions tests/attr/test_llm_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import torch
from captum._utils.models.linear_model import SkLearnLasso
from captum._utils.typing import BatchEncodingType
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
Expand Down Expand Up @@ -96,6 +97,13 @@ def decode(self, token_ids: Tensor) -> str:
# pyre-fixme[7]: Expected `str` but got `Union[List[str], str]`.
return tokens if isinstance(tokens, str) else " ".join(tokens)

# Stub with the same signature as `TokenizerLike.__call__` in
# captum/_utils/typing.py. It is intentionally unimplemented: any call raises
# NotImplementedError. (The enclosing test class is outside this diff hunk.)
def __call__(
self,
text: Optional[Union[str, List[str], List[List[str]]]] = None,
return_offsets_mapping: bool = False,
) -> BatchEncodingType:
raise NotImplementedError


class Result(NamedTuple):
logits: Tensor
Expand Down

0 comments on commit 2c1c281

Please sign in to comment.