Commit

hijack HF's download so it works with gcs etc. (#819)
dlwh authored Nov 24, 2024
1 parent 290ab80 commit a472cd4
Showing 1 changed file with 107 additions and 94 deletions: src/levanter/compat/hf_checkpoints.py
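As far as this diff shows, the net effect is that instead of special-casing fsspec paths only for tokenizers and processors, the file temporarily rebinds transformers' hf_hub_download so that the patched from_pretrained and checkpoint-loading paths can take a URL-like "repo id" (gs://, s3://, and so on) and fetch files through fsspec. A minimal usage sketch; the gs:// path is a placeholder, not something from this commit:

    from levanter.compat.hf_checkpoints import load_tokenizer

    # A bucket directory containing tokenizer files can be passed where a
    # Hub repo id would normally go; the files are fetched via fsspec.
    tokenizer = load_tokenizer("gs://my-bucket/my-model")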
@@ -1,4 +1,5 @@
import abc
+import contextlib
import dataclasses
import json
import logging
@@ -10,7 +11,6 @@
from dataclasses import dataclass
from functools import cached_property
from typing import Generic, Optional, Tuple, Type, TypeVar, Union, cast
-from urllib.parse import urlparse

import draccus
import equinox as eqx
@@ -21,8 +21,9 @@
import mergedeep
import safetensors
import safetensors.numpy
+import transformers.utils.hub
from huggingface_hub import HfApi, hf_hub_download, repo_exists, snapshot_download
-from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, HFValidationError
+from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, HFValidationError, RepositoryNotFoundError
from jax.experimental.multihost_utils import sync_global_devices
from jax.random import PRNGKey
from jaxtyping import Array
@@ -324,11 +325,8 @@ def _infer_config_class(hf_config_class, ref, trust_remote_code):
    if ref is None:
        raise ValueError("Must provide either config class or reference_checkpoint")
    path, rev = ref.model_name_or_path, ref.revision
-    config = AutoConfig.from_pretrained(
-        path,
-        revision=rev,
-        trust_remote_code=trust_remote_code,
-    )
+    with _patch_hf_hub_download():
+        config = AutoConfig.from_pretrained(path, revision=rev, trust_remote_code=trust_remote_code)
    clss = type(config)
elif isinstance(hf_config_class, str):
    if ref is None:
@@ -423,7 +421,9 @@ def config_from_hf_checkpoint(self, ref: Optional[Union[str, RepoRef]] = None) -

def hf_config_from_hf_checkpoint(self, ref: Optional[Union[str, RepoRef]] = None) -> HfConfig:
    path, rev = self._get_ref(ref)
-    config = AutoConfig.from_pretrained(path, revision=rev, trust_remote_code=self.trust_remote_code)
+
+    with _patch_hf_hub_download():
+        config = AutoConfig.from_pretrained(path, revision=rev, trust_remote_code=self.trust_remote_code)
    return config

def _get_ref(self, ref) -> Tuple[str, Optional[str]]:
@@ -450,49 +450,51 @@ def load_state_dict(self, ref: Optional[Union[str, RepoRef]] = None, dtype: Opti
except HFValidationError:
pass

-# TODO: load models from gcs etc.
-if os.path.exists(os.path.join(id, SAFE_TENSORS_MODEL)):
-    state_dict = _load_safe_tensors(os.path.join(id, SAFE_TENSORS_MODEL), dtype)
-elif os.path.exists(os.path.join(id, PYTORCH_MODEL)):
-    state_dict = _load_torch(os.path.join(id, PYTORCH_MODEL), dtype)
-else:
-    try:
-        model_path = hf_hub_download(id, SAFE_TENSORS_MODEL, revision=rev)
-        state_dict = _load_safe_tensors(model_path, dtype)
-    except (EntryNotFoundError, HFValidationError):
-        model_path = hf_hub_download(id, PYTORCH_MODEL, revision=rev)
-        state_dict = _load_torch(model_path, dtype)
+with _patch_hf_hub_download() as hf_hub_download:
+    # TODO: load models from gcs etc.
+    if os.path.exists(os.path.join(id, SAFE_TENSORS_MODEL)):
+        state_dict = _load_safe_tensors(os.path.join(id, SAFE_TENSORS_MODEL), dtype)
+    elif os.path.exists(os.path.join(id, PYTORCH_MODEL)):
+        state_dict = _load_torch(os.path.join(id, PYTORCH_MODEL), dtype)
+    else:
+        try:
+            model_path = hf_hub_download(id, SAFE_TENSORS_MODEL, revision=rev)
+            state_dict = _load_safe_tensors(model_path, dtype)
+        except (EntryNotFoundError, HFValidationError):
+            model_path = hf_hub_download(id, PYTORCH_MODEL, revision=rev)
+            state_dict = _load_torch(model_path, dtype)

-return state_dict
+    return state_dict

def _load_shards(self, id: str, index_file: str, rev: Optional[str], dtype) -> dict:
    """Load model from sharded files based on the provided index."""
-    index_path = os.path.join(id, index_file)
-    if not os.path.exists(index_path):
-        # Download the index file if not found locally
-        index_path = hf_hub_download(id, index_file, revision=rev)
-
-    with open(index_path, "r", encoding="utf-8") as f:
-        index = json.load(f)
-
-    shard_files = list(set(index["weight_map"].values()))
-    final_state_dict = {}
-
-    # right now we do safe tensors thing
-    # where we load into memory then update some dict
-    if "safetensors" in index_file:
-        loader = _load_safe_tensors
-    else:
-        loader = _load_torch
+    with _patch_hf_hub_download() as hf_hub_download:
+        index_path = os.path.join(id, index_file)
+        if not os.path.exists(index_path):
+            # Download the index file if not found locally
+            index_path = hf_hub_download(id, index_file, revision=rev)
+
+        with open(index_path, "r", encoding="utf-8") as f:
+            index = json.load(f)
+
+        shard_files = list(set(index["weight_map"].values()))
+        final_state_dict = {}
+
+        # right now we do safe tensors thing
+        # where we load into memory then update some dict
+        if "safetensors" in index_file:
+            loader = _load_safe_tensors
+        else:
+            loader = _load_torch

-    for shard_file in shard_files:
-        shard_path = os.path.join(id, shard_file)
-        if not os.path.exists(shard_path):
-            # Download the shard if not found locally
-            shard_path = hf_hub_download(id, shard_file, revision=rev)
+        for shard_file in shard_files:
+            shard_path = os.path.join(id, shard_file)
+            if not os.path.exists(shard_path):
+                # Download the shard if not found locally
+                shard_path = hf_hub_download(id, shard_file, revision=rev)

-        shard_state_dict = loader(shard_path, dtype)
-        final_state_dict.update(shard_state_dict)
+            shard_state_dict = loader(shard_path, dtype)
+            final_state_dict.update(shard_state_dict)

    return final_state_dict

@@ -588,22 +590,6 @@ def load_from_state_dict(template, state_dict):
lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0))
lev_model = load_from_state_dict(lev_model, state_dict)

-# all_arrays: list[jax.Array] = get_backend().live_arrays()
-# total_size = sum(a.size * a.itemsize for a in all_arrays)
-# print(f"Total size of live arrays: {total_size / 1e9:.2f} GB")
-# gc.collect()  # sometimes takes a while to free buffers otherwise
-# try:
-#     get_backend().defragment()
-# except Exception as e:
-#     warnings.warn(f"Could not defragment because {e}")
-#     pass
-# all_arrays = get_backend().live_arrays()
-# total_size = sum(a.size * a.itemsize for a in all_arrays)
-# print(f"Total size of live arrays: {total_size / 1e9:.2f} GB")
-# all_arrays = get_backend().live_arrays()
-# total_size = sum(a.size * a.itemsize for a in all_arrays)
-# print(f"Total size of live arrays: {total_size / 1e9:.2f} GB")
-
return lev_model

def _save_pretrained_local(
@@ -874,45 +860,20 @@ def cb(step: StepInfo):
return cb


-def arbitrary_load_from_hf(
-    model_name_or_path, from_pretrained_lambda, revision=None, local_cache_dir=None, trust_remote_code=True
-) -> Union[HfTokenizer | ProcessorMixin]:
-    is_url_like = urlparse(model_name_or_path).scheme != ""
-    if is_url_like:
-        if revision is not None:
-            raise ValueError("revision is not supported for URLs")
-        # tokenizers are directories, so we have to copy them locally
-        if local_cache_dir is None:
-            local_cache_dir = tempfile.mkdtemp()
-
-        fs, path = fsspec.core.url_to_fs(model_name_or_path)
-        fs.get(path, local_cache_dir, recursive=True)
-        base_path = os.path.basename(path)
-        return from_pretrained_lambda(os.path.join(local_cache_dir, base_path), trust_remote_code=trust_remote_code)
-    else:
-        return from_pretrained_lambda(model_name_or_path, revision=revision, trust_remote_code=trust_remote_code)
-
-
def load_tokenizer(model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True) -> HfTokenizer:
    """Like AutoTokenizer.from_pretrained, but works with gs:// paths or anything on fsspec"""
-    return arbitrary_load_from_hf(
-        model_name_or_path,
-        AutoTokenizer.from_pretrained,
-        revision=revision,
-        local_cache_dir=local_cache_dir,
-        trust_remote_code=trust_remote_code,
-    )
+    with _patch_hf_hub_download():
+        return AutoTokenizer.from_pretrained(
+            model_name_or_path, revision=revision, cache_dir=local_cache_dir, trust_remote_code=trust_remote_code
+        )


def load_processor(model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True) -> ProcessorMixin:
    """Like AutoProcessor.from_pretrained, but works with gs:// paths or anything on fsspec"""
-    return arbitrary_load_from_hf(
-        model_name_or_path,
-        AutoProcessor.from_pretrained,
-        revision=revision,
-        local_cache_dir=local_cache_dir,
-        trust_remote_code=trust_remote_code,
-    )
+    with _patch_hf_hub_download():
+        return AutoProcessor.from_pretrained(
+            model_name_or_path, revision=revision, cache_dir=local_cache_dir, trust_remote_code=trust_remote_code
+        )


_sync_count = 0
@@ -1111,3 +1072,55 @@ def _should_use_cpu_for_checkpoint_loading():
return False
if sum(accel_memory) < cpu_memory:
return True


+def _is_hf_hub_model(ref: RepoRef):
+    api = HfApi()
+
+    try:
+        api.model_info(repo_id=ref.model_name_or_path)
+        return True
+    except RepositoryNotFoundError:
+        return False
+
+
+@contextlib.contextmanager
+def _patch_hf_hub_download():
+    """
+    Temporarily monkeypatch `hf_hub_download` to handle fsspec URLs, ensuring the temporary directory
+    persists for the lifetime of the context manager.
+    """
+    original_hf_hub_download = transformers.utils.hub.hf_hub_download
+
+    # Create a temporary directory that persists through the context manager
+    with tempfile.TemporaryDirectory() as tmpdir:
+
+        def custom_hf_hub_download(*args, **kwargs):
+            """
+            Custom implementation of hf_hub_download to handle fsspec URLs.
+            """
+            repo_id = kwargs.get("repo_id", args[0] if len(args) > 0 else None)
+            filename = kwargs.get("filename", args[1] if len(args) > 1 else None)
+
+            if repo_id and filename and _is_url_like(repo_id):
+                fs, path = fsspec.core.url_to_fs(repo_id)
+                remote_path = os.path.join(path, filename)
+                local_path = os.path.join(tmpdir, filename)
+
+                if not fs.exists(remote_path):
+                    raise EntryNotFoundError(f"File {remote_path} not found")
+
+                fs.get(remote_path, local_path)
+                return local_path
+
+            # Fallback to the original implementation
+            return original_hf_hub_download(*args, **kwargs)
+
+        # Monkeypatch hf_hub_download
+        transformers.utils.hub.hf_hub_download = custom_hf_hub_download
+
+        try:
+            yield custom_hf_hub_download
+        finally:
+            # Restore the original implementation
+            transformers.utils.hub.hf_hub_download = original_hf_hub_download
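Note how the new context manager is consumed elsewhere in this diff: transformers' internal downloads pick the patch up because transformers.utils.hub.hf_hub_download is rebound, while Levanter's own calls use the function the context manager yields (the `as hf_hub_download` in load_state_dict and _load_shards). A minimal sketch of that calling pattern, with gs://my-bucket/my-model as a placeholder path:

    with _patch_hf_hub_download() as hf_hub_download:
        # Inside the context, both names refer to the same fsspec-aware wrapper.
        assert transformers.utils.hub.hf_hub_download is hf_hub_download
        # A URL-like "repo id" is resolved with fsspec and copied into the
        # context manager's temporary directory; the remote file must exist,
        # otherwise EntryNotFoundError is raised.
        local_path = hf_hub_download("gs://my-bucket/my-model", "config.json")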
