From b9e38b2a2e2be2b5fb31842fa409b95abcbccbc6 Mon Sep 17 00:00:00 2001
From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
Date: Tue, 19 Dec 2023 05:21:18 -0500
Subject: [PATCH] chore(client): only import tokenizers when needed (#284)

https://github.com/anthropics/anthropic-sdk-python/issues/280
---
 src/anthropic/_client.py     |  6 +++---
 src/anthropic/_tokenizers.py | 21 ++++++++++++++-------
 tests/test_tokenizer.py      | 12 ++++++++++++
 3 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/src/anthropic/_client.py b/src/anthropic/_client.py
index be2ac5ea..b4e27c58 100644
--- a/src/anthropic/_client.py
+++ b/src/anthropic/_client.py
@@ -7,7 +7,6 @@
 from typing_extensions import Self, override
 
 import httpx
-from tokenizers import Tokenizer  # type: ignore[import]
 
 from . import resources, _constants, _exceptions
 from ._qs import Querystring
@@ -27,6 +26,7 @@
 from ._streaming import Stream as Stream
 from ._streaming import AsyncStream as AsyncStream
 from ._exceptions import APIStatusError
+from ._tokenizers import TokenizerType  # type: ignore[import]
 from ._tokenizers import sync_get_tokenizer, async_get_tokenizer
 from ._base_client import (
     DEFAULT_LIMITS,
@@ -264,7 +264,7 @@ def count_tokens(
         encoded_text = tokenizer.encode(text)  # type: ignore
         return len(encoded_text.ids)  # type: ignore
 
-    def get_tokenizer(self) -> Tokenizer:
+    def get_tokenizer(self) -> TokenizerType:
         return sync_get_tokenizer()
 
     @override
@@ -515,7 +515,7 @@ async def count_tokens(
         encoded_text = tokenizer.encode(text)  # type: ignore
         return len(encoded_text.ids)  # type: ignore
 
-    async def get_tokenizer(self) -> Tokenizer:
+    async def get_tokenizer(self) -> TokenizerType:
         return await async_get_tokenizer()
 
     @override
diff --git a/src/anthropic/_tokenizers.py b/src/anthropic/_tokenizers.py
index c2ac9208..e5fd95c3 100644
--- a/src/anthropic/_tokenizers.py
+++ b/src/anthropic/_tokenizers.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import cast
+from typing import TYPE_CHECKING, cast
 from pathlib import Path
 
 from anyio import Path as AsyncPath
@@ -8,24 +8,31 @@
 # tokenizers is untyped, https://github.com/huggingface/tokenizers/issues/811
 # note: this comment affects the entire file
 # pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
-from tokenizers import Tokenizer  # type: ignore[import]
+if TYPE_CHECKING:
+    # we only import this at the type-level as deferring the import
+    # avoids issues like this: https://github.com/anthropics/anthropic-sdk-python/issues/280
+    from tokenizers import Tokenizer as TokenizerType  # type: ignore[import]
+else:
+    TokenizerType = None
 
 
 def _get_tokenizer_cache_path() -> Path:
     return Path(__file__).parent / "tokenizer.json"
 
 
-_tokenizer: Tokenizer | None = None
+_tokenizer: TokenizerType | None = None
 
 
-def _load_tokenizer(raw: str) -> Tokenizer:
+def _load_tokenizer(raw: str) -> TokenizerType:
     global _tokenizer
 
-    _tokenizer = cast(Tokenizer, Tokenizer.from_str(raw))
+    from tokenizers import Tokenizer
+
+    _tokenizer = cast(TokenizerType, Tokenizer.from_str(raw))
     return _tokenizer
 
 
-def sync_get_tokenizer() -> Tokenizer:
+def sync_get_tokenizer() -> TokenizerType:
     if _tokenizer is not None:
         return _tokenizer
 
@@ -34,7 +41,7 @@ def sync_get_tokenizer() -> Tokenizer:
     return _load_tokenizer(text)
 
 
-async def async_get_tokenizer() -> Tokenizer:
+async def async_get_tokenizer() -> TokenizerType:
     if _tokenizer is not None:
         return _tokenizer
 
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
index 87be4044..802494d5 100644
--- a/tests/test_tokenizer.py
+++ b/tests/test_tokenizer.py
@@ -1,3 +1,4 @@
+import sys
 import asyncio
 import threading
 import concurrent.futures
@@ -20,6 +21,17 @@ def _sync_tokenizer_test() -> None:
     assert len(encoded_text.ids) == 2  # type: ignore
 
 
+def test_tokenizers_is_not_imported() -> None:
+    # note: this test relies on being executed before any of the
+    # other tests but is a valuable test to avoid issues like this
+    # https://github.com/anthropics/anthropic-sdk-python/issues/280
+    assert "tokenizers" not in sys.modules
+
+    _sync_tokenizer_test()
+
+    assert "tokenizers" in sys.modules
+
+
 def test_threading() -> None:
     failed = False
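
Note: for reference, the deferral technique applied in src/anthropic/_tokenizers.py above can be
sketched in isolation. This is a minimal, hypothetical example and not part of the patch: the
standard-library decimal.Decimal stands in for the heavy tokenizers.Tokenizer dependency, and the
names DecimalType, _value, and get_value are made up for illustration only.

# deferred_import_sketch.py -- standalone sketch of the TYPE_CHECKING / lazy-import pattern
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    # imported only for type checkers; nothing is loaded at runtime here
    from decimal import Decimal as DecimalType
else:
    DecimalType = None

_value: DecimalType | None = None


def get_value() -> DecimalType:
    global _value
    if _value is not None:
        return _value

    # the heavy import happens lazily, on the first call that needs it
    from decimal import Decimal

    _value = cast(DecimalType, Decimal("1.0"))
    return _value


if __name__ == "__main__":
    assert "decimal" not in sys.modules  # importing this module stayed cheap
    print(get_value())  # first use triggers the deferred import
    assert "decimal" in sys.modules

Because "from __future__ import annotations" defers annotation evaluation, the DecimalType name can
be None at runtime while still giving type checkers the real class, which is the same trick the
patch relies on for TokenizerType.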