
Commit b9e38b2
chore(client): only import tokenizers when needed (#284)
stainless-bot authored Dec 19, 2023
1 parent 3a23912 commit b9e38b2
Showing 3 changed files with 29 additions and 10 deletions.
6 changes: 3 additions & 3 deletions src/anthropic/_client.py
@@ -7,7 +7,6 @@
 from typing_extensions import Self, override

 import httpx
-from tokenizers import Tokenizer  # type: ignore[import]

 from . import resources, _constants, _exceptions
 from ._qs import Querystring
@@ -27,6 +26,7 @@
 from ._streaming import Stream as Stream
 from ._streaming import AsyncStream as AsyncStream
 from ._exceptions import APIStatusError
+from ._tokenizers import TokenizerType  # type: ignore[import]
 from ._tokenizers import sync_get_tokenizer, async_get_tokenizer
 from ._base_client import (
     DEFAULT_LIMITS,
@@ -264,7 +264,7 @@ def count_tokens(
         encoded_text = tokenizer.encode(text)  # type: ignore
         return len(encoded_text.ids)  # type: ignore

-    def get_tokenizer(self) -> Tokenizer:
+    def get_tokenizer(self) -> TokenizerType:
         return sync_get_tokenizer()

     @override
@@ -515,7 +515,7 @@ async def count_tokens(
         encoded_text = tokenizer.encode(text)  # type: ignore
         return len(encoded_text.ids)  # type: ignore

-    async def get_tokenizer(self) -> Tokenizer:
+    async def get_tokenizer(self) -> TokenizerType:
         return await async_get_tokenizer()

     @override
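For context, the public surface this touches is the SDK's local token-counting helpers: constructing the client no longer imports `tokenizers`; the dependency is loaded on first use. A minimal usage sketch against the 0.x-era methods shown in the diff (not part of the commit):

```python
from anthropic import Anthropic

# placeholder key: token counting is local, no API call is made
client = Anthropic(api_key="my-anthropic-api-key")

# the first call performs the deferred `tokenizers` import and caches the result
print(client.count_tokens("hello, world!"))

# the tokenizer object itself is still available to callers
tokenizer = client.get_tokenizer()
print(len(tokenizer.encode("hello, world!").ids))
```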
21 changes: 14 additions & 7 deletions src/anthropic/_tokenizers.py
@@ -1,31 +1,38 @@
 from __future__ import annotations

-from typing import cast
+from typing import TYPE_CHECKING, cast
 from pathlib import Path

 from anyio import Path as AsyncPath

 # tokenizers is untyped, https://github.com/huggingface/tokenizers/issues/811
 # note: this comment affects the entire file
 # pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
-from tokenizers import Tokenizer  # type: ignore[import]
+if TYPE_CHECKING:
+    # we only import this at the type-level as deferring the import
+    # avoids issues like this: https://github.com/anthropics/anthropic-sdk-python/issues/280
+    from tokenizers import Tokenizer as TokenizerType  # type: ignore[import]
+else:
+    TokenizerType = None


 def _get_tokenizer_cache_path() -> Path:
     return Path(__file__).parent / "tokenizer.json"


-_tokenizer: Tokenizer | None = None
+_tokenizer: TokenizerType | None = None


-def _load_tokenizer(raw: str) -> Tokenizer:
+def _load_tokenizer(raw: str) -> TokenizerType:
     global _tokenizer

-    _tokenizer = cast(Tokenizer, Tokenizer.from_str(raw))
+    from tokenizers import Tokenizer
+
+    _tokenizer = cast(TokenizerType, Tokenizer.from_str(raw))
     return _tokenizer


-def sync_get_tokenizer() -> Tokenizer:
+def sync_get_tokenizer() -> TokenizerType:
     if _tokenizer is not None:
         return _tokenizer

@@ -34,7 +41,7 @@ def sync_get_tokenizer() -> Tokenizer:
     return _load_tokenizer(text)


-async def async_get_tokenizer() -> Tokenizer:
+async def async_get_tokenizer() -> TokenizerType:
     if _tokenizer is not None:
         return _tokenizer

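The `_tokenizers.py` change is the standard `typing.TYPE_CHECKING` pattern: the import is visible to static type checkers, while at runtime the module is only loaded inside the function that first needs it. A self-contained sketch of the same idea, using `decimal` as a stand-in for a heavy dependency (illustrative, not from the commit):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # only evaluated by static type checkers, never at runtime
    from decimal import Decimal
else:
    Decimal = None  # keep the name bound at runtime, mirroring `TokenizerType = None`

_cached: Decimal | None = None


def get_value() -> Decimal:
    """Load the heavy dependency on first use and cache the result."""
    global _cached
    if _cached is None:
        # deferred import: `decimal` is only loaded when this runs,
        # not when the enclosing module is imported
        from decimal import Decimal

        _cached = Decimal("3.14159")
    return _cached


print(get_value() * 2)
```

Because `from __future__ import annotations` defers annotation evaluation, the `Decimal | None` annotation and the `-> Decimal` return type never touch the runtime `None` fallback.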
12 changes: 12 additions & 0 deletions tests/test_tokenizer.py
@@ -1,3 +1,4 @@
+import sys
 import asyncio
 import threading
 import concurrent.futures
@@ -20,6 +21,17 @@ def _sync_tokenizer_test() -> None:
     assert len(encoded_text.ids) == 2  # type: ignore


+def test_tokenizers_is_not_imported() -> None:
+    # note: this test relies on being executed before any of the
+    # other tests but is a valuable test to avoid issues like this
+    # https://github.com/anthropics/anthropic-sdk-python/issues/280
+    assert "tokenizers" not in sys.modules
+
+    _sync_tokenizer_test()
+
+    assert "tokenizers" in sys.modules
+
+
 def test_threading() -> None:
     failed = False
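As its comment notes, `test_tokenizers_is_not_imported` only works if it runs before any other test has pulled in `tokenizers`. An order-independent variant would do the check in a fresh interpreter; a possible sketch (assumes the `anthropic` package is importable, and is not part of the commit):

```python
import subprocess
import sys


def test_tokenizers_not_imported_in_fresh_interpreter() -> None:
    # a clean child process cannot have been polluted by earlier tests
    code = "import sys, anthropic; assert 'tokenizers' not in sys.modules"
    subprocess.run([sys.executable, "-c", code], check=True)
```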
