Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hijack HF's download so it works with gcs etc. #819

Merged
merged 1 commit into from
Nov 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 107 additions & 94 deletions src/levanter/compat/hf_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import contextlib
import dataclasses
import json
import logging
Expand All @@ -10,7 +11,6 @@
from dataclasses import dataclass
from functools import cached_property
from typing import Generic, Optional, Tuple, Type, TypeVar, Union, cast
from urllib.parse import urlparse

import draccus
import equinox as eqx
Expand All @@ -21,8 +21,9 @@
import mergedeep
import safetensors
import safetensors.numpy
import transformers.utils.hub
from huggingface_hub import HfApi, hf_hub_download, repo_exists, snapshot_download
from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, HFValidationError
from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, HFValidationError, RepositoryNotFoundError
from jax.experimental.multihost_utils import sync_global_devices
from jax.random import PRNGKey
from jaxtyping import Array
Expand Down Expand Up @@ -324,11 +325,8 @@ def _infer_config_class(hf_config_class, ref, trust_remote_code):
if ref is None:
raise ValueError("Must provide either config class or reference_checkpoint")
path, rev = ref.model_name_or_path, ref.revision
config = AutoConfig.from_pretrained(
path,
revision=rev,
trust_remote_code=trust_remote_code,
)
with _patch_hf_hub_download():
config = AutoConfig.from_pretrained(path, revision=rev, trust_remote_code=trust_remote_code)
clss = type(config)
elif isinstance(hf_config_class, str):
if ref is None:
Expand Down Expand Up @@ -423,7 +421,9 @@ def config_from_hf_checkpoint(self, ref: Optional[Union[str, RepoRef]] = None) -

def hf_config_from_hf_checkpoint(self, ref: Optional[Union[str, RepoRef]] = None) -> HfConfig:
path, rev = self._get_ref(ref)
config = AutoConfig.from_pretrained(path, revision=rev, trust_remote_code=self.trust_remote_code)

with _patch_hf_hub_download():
config = AutoConfig.from_pretrained(path, revision=rev, trust_remote_code=self.trust_remote_code)
return config

def _get_ref(self, ref) -> Tuple[str, Optional[str]]:
Expand All @@ -450,49 +450,51 @@ def load_state_dict(self, ref: Optional[Union[str, RepoRef]] = None, dtype: Opti
except HFValidationError:
pass

# TODO: load models from gcs etc.
if os.path.exists(os.path.join(id, SAFE_TENSORS_MODEL)):
state_dict = _load_safe_tensors(os.path.join(id, SAFE_TENSORS_MODEL), dtype)
elif os.path.exists(os.path.join(id, PYTORCH_MODEL)):
state_dict = _load_torch(os.path.join(id, PYTORCH_MODEL), dtype)
else:
try:
model_path = hf_hub_download(id, SAFE_TENSORS_MODEL, revision=rev)
state_dict = _load_safe_tensors(model_path, dtype)
except (EntryNotFoundError, HFValidationError):
model_path = hf_hub_download(id, PYTORCH_MODEL, revision=rev)
state_dict = _load_torch(model_path, dtype)
with _patch_hf_hub_download() as hf_hub_download:
# TODO: load models from gcs etc.
if os.path.exists(os.path.join(id, SAFE_TENSORS_MODEL)):
state_dict = _load_safe_tensors(os.path.join(id, SAFE_TENSORS_MODEL), dtype)
elif os.path.exists(os.path.join(id, PYTORCH_MODEL)):
state_dict = _load_torch(os.path.join(id, PYTORCH_MODEL), dtype)
else:
try:
model_path = hf_hub_download(id, SAFE_TENSORS_MODEL, revision=rev)
state_dict = _load_safe_tensors(model_path, dtype)
except (EntryNotFoundError, HFValidationError):
model_path = hf_hub_download(id, PYTORCH_MODEL, revision=rev)
state_dict = _load_torch(model_path, dtype)

return state_dict
return state_dict

def _load_shards(self, id: str, index_file: str, rev: Optional[str], dtype) -> dict:
"""Load model from sharded files based on the provided index."""
index_path = os.path.join(id, index_file)
if not os.path.exists(index_path):
# Download the index file if not found locally
index_path = hf_hub_download(id, index_file, revision=rev)

with open(index_path, "r", encoding="utf-8") as f:
index = json.load(f)

shard_files = list(set(index["weight_map"].values()))
final_state_dict = {}

# right now we do safe tensors thing
# where we load into memory then update some dict
if "safetensors" in index_file:
loader = _load_safe_tensors
else:
loader = _load_torch
with _patch_hf_hub_download() as hf_hub_download:
index_path = os.path.join(id, index_file)
if not os.path.exists(index_path):
# Download the index file if not found locally
index_path = hf_hub_download(id, index_file, revision=rev)

with open(index_path, "r", encoding="utf-8") as f:
index = json.load(f)

shard_files = list(set(index["weight_map"].values()))
final_state_dict = {}

# right now we do safe tensors thing
# where we load into memory then update some dict
if "safetensors" in index_file:
loader = _load_safe_tensors
else:
loader = _load_torch

for shard_file in shard_files:
shard_path = os.path.join(id, shard_file)
if not os.path.exists(shard_path):
# Download the shard if not found locally
shard_path = hf_hub_download(id, shard_file, revision=rev)
for shard_file in shard_files:
shard_path = os.path.join(id, shard_file)
if not os.path.exists(shard_path):
# Download the shard if not found locally
shard_path = hf_hub_download(id, shard_file, revision=rev)

shard_state_dict = loader(shard_path, dtype)
final_state_dict.update(shard_state_dict)
shard_state_dict = loader(shard_path, dtype)
final_state_dict.update(shard_state_dict)

return final_state_dict

Expand Down Expand Up @@ -588,22 +590,6 @@ def load_from_state_dict(template, state_dict):
lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0))
lev_model = load_from_state_dict(lev_model, state_dict)

# all_arrays: list[jax.Array] = get_backend().live_arrays()
# total_size = sum(a.size * a.itemsize for a in all_arrays)
# print(f"Total size of live arrays: {total_size / 1e9:.2f} GB")
# gc.collect() # sometimes takes a while to free buffers otherwise
# try:
# get_backend().defragment()
# except Exception as e:
# warnings.warn(f"Could not defragment because {e}")
# pass
# all_arrays = get_backend().live_arrays()
# total_size = sum(a.size * a.itemsize for a in all_arrays)
# print(f"Total size of live arrays: {total_size / 1e9:.2f} GB")
# all_arrays = get_backend().live_arrays()
# total_size = sum(a.size * a.itemsize for a in all_arrays)
# print(f"Total size of live arrays: {total_size / 1e9:.2f} GB")

return lev_model

def _save_pretrained_local(
Expand Down Expand Up @@ -874,45 +860,20 @@ def cb(step: StepInfo):
return cb


def arbitrary_load_from_hf(
model_name_or_path, from_pretrained_lambda, revision=None, local_cache_dir=None, trust_remote_code=True
) -> Union[HfTokenizer | ProcessorMixin]:
is_url_like = urlparse(model_name_or_path).scheme != ""
if is_url_like:
if revision is not None:
raise ValueError("revision is not supported for URLs")
# tokenizers are directories, so we have to copy them locally
if local_cache_dir is None:
local_cache_dir = tempfile.mkdtemp()

fs, path = fsspec.core.url_to_fs(model_name_or_path)
fs.get(path, local_cache_dir, recursive=True)
base_path = os.path.basename(path)
return from_pretrained_lambda(os.path.join(local_cache_dir, base_path), trust_remote_code=trust_remote_code)
else:
return from_pretrained_lambda(model_name_or_path, revision=revision, trust_remote_code=trust_remote_code)


def load_tokenizer(model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True) -> HfTokenizer:
"""Like AutoTokenizer.from_pretrained, but works with gs:// paths or anything on fsspec"""
return arbitrary_load_from_hf(
model_name_or_path,
AutoTokenizer.from_pretrained,
revision=revision,
local_cache_dir=local_cache_dir,
trust_remote_code=trust_remote_code,
)
with _patch_hf_hub_download():
return AutoTokenizer.from_pretrained(
model_name_or_path, revision=revision, cache_dir=local_cache_dir, trust_remote_code=trust_remote_code
)


def load_processor(model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True) -> ProcessorMixin:
"""Like AutoProcessor.from_pretrained, but works with gs:// paths or anything on fsspec"""
return arbitrary_load_from_hf(
model_name_or_path,
AutoProcessor.from_pretrained,
revision=revision,
local_cache_dir=local_cache_dir,
trust_remote_code=trust_remote_code,
)
with _patch_hf_hub_download():
return AutoProcessor.from_pretrained(
model_name_or_path, revision=revision, cache_dir=local_cache_dir, trust_remote_code=trust_remote_code
)


_sync_count = 0
Expand Down Expand Up @@ -1111,3 +1072,55 @@ def _should_use_cpu_for_checkpoint_loading():
return False
if sum(accel_memory) < cpu_memory:
return True


def _is_hf_hub_model(ref: RepoRef) -> bool:
    """Return True iff ``ref`` names a model repo that actually exists on the HF Hub."""
    try:
        HfApi().model_info(repo_id=ref.model_name_or_path)
    except RepositoryNotFoundError:
        return False
    return True


@contextlib.contextmanager
def _patch_hf_hub_download():
    """
    Temporarily monkeypatch ``transformers.utils.hub.hf_hub_download`` so that fsspec URLs
    (e.g. ``gs://bucket/path``) can be passed where a repo id is expected. Downloaded files
    are staged in a temporary directory whose lifetime is tied to this context manager, so
    returned paths stay valid until the ``with`` block exits.

    Yields:
        The patched download callable (same signature as ``hf_hub_download``), so callers
        can also invoke it directly.
    """
    original_hf_hub_download = transformers.utils.hub.hf_hub_download

    # The temp directory must outlive individual download calls: paths returned by the
    # patched function are read by the caller after the function returns.
    with tempfile.TemporaryDirectory() as tmpdir:

        def custom_hf_hub_download(*args, **kwargs):
            """``hf_hub_download`` that understands fsspec URLs; otherwise defers to the original."""
            repo_id = kwargs.get("repo_id", args[0] if len(args) > 0 else None)
            filename = kwargs.get("filename", args[1] if len(args) > 1 else None)

            if repo_id and filename and _is_url_like(repo_id):
                fs, path = fsspec.core.url_to_fs(repo_id)
                remote_path = os.path.join(path, filename)
                local_path = os.path.join(tmpdir, filename)

                if not fs.exists(remote_path):
                    raise EntryNotFoundError(f"File {remote_path} not found")

                # `filename` may contain a subdirectory component (e.g. weights kept in a
                # subfolder of the checkpoint); create the destination directory first or
                # fs.get would fail writing into a nonexistent path.
                os.makedirs(os.path.dirname(local_path), exist_ok=True)
                fs.get(remote_path, local_path)
                return local_path

            # Not an fsspec URL: fall back to the original implementation.
            return original_hf_hub_download(*args, **kwargs)

        # Monkeypatch hf_hub_download for the duration of the context.
        transformers.utils.hub.hf_hub_download = custom_hf_hub_download

        try:
            yield custom_hf_hub_download
        finally:
            # Always restore the original, even if the body raises.
            transformers.utils.hub.hf_hub_download = original_hf_hub_download
Loading