Download entire model upfront
This change downloads all model files upfront, even if we only need the
tokenizer initially. This simplifies offline caching and eliminates a bug
caused by missing model files.
jncraton committed Aug 29, 2024
1 parent 9857154 commit a7fd9ae
Showing 2 changed files with 10 additions and 5 deletions.
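
For context, the caching pattern this commit adopts can be sketched directly against the huggingface_hub API. This is a minimal sketch, not code from the repository; the repo id is a placeholder:

from huggingface_hub import hf_hub_download, snapshot_download

REPO = "example-org/example-model"  # placeholder repo id, not from this repo

# Download the full snapshot once: weights, vocab, and config files all
# land in the local cache in a single pass.
path = snapshot_download(
    REPO, max_workers=1, allow_patterns=["*.bin", "*.txt", "*.json"]
)

# Later lookups resolve individual files purely from that cache;
# local_files_only=True errors out instead of touching the network
# if the file was never downloaded.
tok_config = hf_hub_download(REPO, "tokenizer.json", local_files_only=True)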
changelog.md (1 addition, 0 deletions)
@@ -5,6 +5,7 @@
 ### Changed
 
 - Skip checking for model updates
+- Download entire model upfront even if we only need the tokenizer initially
 
 ## 0.20 - 2024-04-25
languagemodels/models.py (9 additions, 5 deletions)
@@ -33,8 +33,11 @@ def get_model_info(model_type="instruct"):
 
 def initialize_tokenizer(model_type, model_name):
     model_info = get_model_info(model_type)
+    rev = model_info.get("revision", None)
 
-    tok_config = hf_hub_download(model_info["path"], "tokenizer.json")
+    tok_config = hf_hub_download(
+        model_info["path"], "tokenizer.json", revision=rev, local_files_only=True
+    )
     tokenizer = Tokenizer.from_file(tok_config)
 
     if model_type == "embedding":
@@ -44,7 +47,7 @@ def initialize_tokenizer(model_type, model_name):
     return tokenizer
 
 
-def initialize_model(model_type, model_name):
+def initialize_model(model_type, model_name, tokenizer_only=False):
    model_info = get_model_info(model_type)
 
     allowed = ["*.bin", "*.txt", "*.json"]
@@ -67,6 +70,9 @@ def initialize_model(model_type, model_name):
         model_info["path"], max_workers=1, allow_patterns=allowed, revision=rev
     )
 
+    if tokenizer_only:
+        return None
+
     if model_info["architecture"] == "encoder-only-transformer":
         return ctranslate2.Encoder(
             path,
@@ -111,10 +117,8 @@ def get_model(model_type, tokenizer_only=False):
         pass
 
     if model_name not in modelcache:
+        model = initialize_model(model_type, model_name, tokenizer_only)
         tokenizer = initialize_tokenizer(model_type, model_name)
-        model = None
-        if not tokenizer_only:
-            model = initialize_model(model_type, model_name)
         modelcache[model_name] = (tokenizer, model)
     elif not tokenizer_only:
         # Make sure model is loaded if we've never loaded it
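
Note the ordering in the final hunk: initialize_model now runs before initialize_tokenizer, so the full snapshot is already in the local cache by the time hf_hub_download is called with local_files_only=True. A hypothetical caller, assuming get_model returns the cached (tokenizer, model) pair:

# tokenizer_only still downloads every file, but skips loading weights
# into memory; initialize_model returns None in that case.
tokenizer, model = get_model("instruct", tokenizer_only=True)
assert model is None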
