Skip to content

Commit

Permalink
Merge pull request #375 from swarajpande5/hf-https-pulls
Browse files Browse the repository at this point in the history
Replace `huggingface-cli download` command with simple https client to pull models
  • Loading branch information
ericcurtin authored Oct 26, 2024
2 parents c5dbbe0 + 5749154 commit 566d7d0
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 147 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ build
ramalama/*.patch
dist
.#*
venv/
65 changes: 65 additions & 0 deletions ramalama/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import string
import subprocess
import sys
import urllib.request

x = False

Expand Down Expand Up @@ -154,3 +155,67 @@ def default_image():

def genname():
    """Return a random name: ``ramalama_`` followed by 10 ASCII letters/digits."""
    alphabet = string.ascii_letters + string.digits
    suffix = "".join(random.choices(alphabet, k=10))
    return f"ramalama_{suffix}"


def download_file(url, dest_path, headers=None, show_progress=True):
    """Download ``url`` into ``dest_path``, resuming a partial download when possible.

    If ``dest_path`` already exists, its size is sent as the start offset of an
    HTTP ``Range`` request so that only the missing tail is fetched. The
    existing bytes are kept only when the server answers ``206 Partial
    Content``; any other successful response carries the whole body, so the
    file is rewritten from the start instead of being appended to.

    Args:
        url: Location to fetch; opened with ``urllib.request.urlopen``.
        dest_path: Local file path to write to.
        headers: Optional mapping of extra request headers.
        show_progress: When true, render a tqdm progress bar and report an
            already-complete file on stdout. tqdm is imported only in this case.

    Raises:
        NotImplementedError: ``show_progress`` is true and tqdm is not installed.
        urllib.error.HTTPError: Any HTTP error other than 416.
    """
    if show_progress:
        try:
            from tqdm import tqdm
        except ImportError as e:
            # A missing module raises ImportError (ModuleNotFoundError),
            # never FileNotFoundError.
            raise NotImplementedError(
                """\
Ollama models requires the tqdm modules.
This module can be installed via PyPi tools like pip, pip3, pipx or via
distribution package managers like dnf or apt. Example:
pip install tqdm
"""
            ) from e

    # Size of any partially downloaded file; 0 means nothing to resume.
    downloaded_size = os.path.getsize(dest_path) if os.path.exists(dest_path) else 0

    request = urllib.request.Request(url, headers=headers or {})
    request.add_header("Range", f"bytes={downloaded_size}-")  # Ask only for the missing tail

    bar_format = "Pulling {desc}: {percentage:3.0f}% ▕{bar:20}▏ {n_fmt}/{total_fmt} {rate_fmt} {remaining}"
    chunk_size = 8192  # 8 KB chunks

    try:
        with urllib.request.urlopen(request) as response:
            # Only a 206 reply continues from our offset. Anything else (e.g. a
            # plain 200 from a server that ignores Range) is the complete body,
            # so appending it would corrupt the file.
            resumed = downloaded_size > 0 and getattr(response, "status", None) == 206
            if not resumed:
                downloaded_size = 0
            total_size = int(response.headers.get("Content-Length", 0)) + downloaded_size

            progress_bar = None
            if show_progress:
                progress_bar = tqdm(
                    desc=os.path.basename(dest_path),
                    total=total_size,
                    initial=downloaded_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                    bar_format=bar_format,
                    ascii=True,
                )
            try:
                with open(dest_path, "ab" if resumed else "wb") as file:
                    while True:
                        chunk = response.read(chunk_size)
                        if not chunk:
                            break
                        file.write(chunk)
                        if progress_bar is not None:
                            progress_bar.update(len(chunk))
            finally:
                if progress_bar is not None:
                    progress_bar.close()
    except urllib.error.HTTPError as e:
        # 416 (range not satisfiable) means the requested offset is already at
        # the end of the resource, i.e. the file is fully downloaded.
        if e.code != 416:
            raise
        if show_progress:
            print(f"File {url} already fully downloaded.")
170 changes: 88 additions & 82 deletions ramalama/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,102 @@
import os
from ramalama.common import run_cmd, exec_cmd
import urllib.request
from ramalama.common import run_cmd, exec_cmd, download_file, verify_checksum
from ramalama.model import Model

missing_huggingface = """
Huggingface models requires the huggingface-cli and tqdm modules.
These modules can be installed via PyPi tools like pip, pip3, pipx or via
Optional: Huggingface models require the huggingface-cli and tqdm modules.
These modules can be installed via PyPi tools like pip, pip3, pipx, or via
distribution package managers like dnf or apt. Example:
pip install huggingface_hub tqdm
"""


def is_huggingface_cli_available():
    """Report whether the ``huggingface-cli`` executable can be launched.

    Runs ``huggingface-cli version``; if that raises ``FileNotFoundError`` a
    notice with installation hints is printed and ``False`` is returned.
    """
    try:
        run_cmd(["huggingface-cli", "version"])
    except FileNotFoundError:
        notice = "huggingface-cli not found. Some features may be limited.\n" + missing_huggingface
        print(notice)
        return False
    return True


def fetch_checksum_from_api(url):
    """Return the SHA-256 digest advertised in the document at ``url``.

    The response body is decoded as text and scanned for the first line that
    starts with ``oid sha256:``; the remainder of that line, stripped of
    surrounding whitespace, is returned.

    Raises:
        ValueError: No line in the response starts with ``oid sha256:``.
    """
    marker = "oid sha256:"
    with urllib.request.urlopen(url) as response:
        text = response.read().decode()

    for entry in text.splitlines():
        if entry.startswith(marker):
            # Everything after the marker's colon is the digest.
            return entry[len(marker):].strip()

    raise ValueError("SHA-256 checksum not found in the API response.")


class Huggingface(Model):
def __init__(self, model):
model = model.removeprefix("huggingface://")
model = model.removeprefix("hf://")
super().__init__(model)
self.type = "HuggingFace"
split = self.model.rsplit("/", 1)
self.directory = ""
if len(split) > 1:
self.directory = split[0]
self.filename = split[1]
else:
self.filename = split[0]
self.directory = split[0] if len(split) > 1 else ""
self.filename = split[1] if len(split) > 1 else split[0]
self.hf_cli_available = is_huggingface_cli_available()

def login(self, args):
if not self.hf_cli_available:
print("huggingface-cli not available, skipping login.")
return
conman_args = ["huggingface-cli", "login"]
if args.token:
conman_args.extend(["--token", args.token])
try:
self.exec(conman_args)
except FileNotFoundError as e:
raise NotImplementedError(
"""\
%s
%s"""
% (str(e).strip("'"), missing_huggingface)
)
self.exec(conman_args)

def logout(self, args):
if not self.hf_cli_available:
print("huggingface-cli not available, skipping logout.")
return
conman_args = ["huggingface-cli", "logout"]
if args.token:
conman_args.extend(["--token", args.token])
conman_args.extend(args)
self.exec(conman_args)

def path(self, args):
    """Return the model's path, which is whatever ``symlink_path(args)`` yields."""
    model_path = self.symlink_path(args)
    return model_path

def pull(self, args):
relative_target_path = ""
symlink_path = self.symlink_path(args)
directory_path = os.path.join(args.store, "repos", "huggingface", self.directory, self.filename)
os.makedirs(directory_path, exist_ok=True)

symlink_dir = os.path.dirname(symlink_path)
os.makedirs(symlink_dir, exist_ok=True)

# Fetch the SHA-256 checksum from the API
checksum_api_url = f"https://huggingface.co/{self.directory}/raw/main/{self.filename}"
sha256_checksum = fetch_checksum_from_api(checksum_api_url)

gguf_path = self.download(args.store)
relative_target_path = os.path.relpath(gguf_path.rstrip(), start=os.path.dirname(symlink_path))
directory = f"{args.store}/models/huggingface/{self.directory}"
os.makedirs(directory, exist_ok=True)
target_path = os.path.join(directory_path, f"sha256:{sha256_checksum}")

if os.path.exists(target_path) and verify_checksum(target_path):
relative_target_path = os.path.relpath(target_path, start=os.path.dirname(symlink_path))
if not self.check_valid_symlink_path(relative_target_path, symlink_path):
run_cmd(["ln", "-sf", relative_target_path, symlink_path], debug=args.debug)
return symlink_path

if os.path.exists(symlink_path) and os.readlink(symlink_path) == relative_target_path:
# Download the model file to the target path
url = f"https://huggingface.co/{self.directory}/resolve/main/{self.filename}"
download_file(url, target_path, headers={}, show_progress=True)

if not verify_checksum(target_path):
print(f"Checksum mismatch for {target_path}, retrying download...")
os.remove(target_path)
download_file(url, target_path, headers={}, show_progress=True)
if not verify_checksum(target_path):
raise ValueError(f"Checksum verification failed for {target_path}")

relative_target_path = os.path.relpath(target_path, start=os.path.dirname(symlink_path))
if self.check_valid_symlink_path(relative_target_path, symlink_path):
# Symlink is already correct, no need to update it
return symlink_path

Expand All @@ -66,67 +105,34 @@ def pull(self, args):
return symlink_path

def push(self, source, args):
try:
proc = run_cmd(
[
"huggingface-cli",
"upload",
"--repo-type",
"model",
self.directory,
self.filename,
"--cache-dir",
args.store + "/repos/huggingface/.cache",
"--local-dir",
args.store + "/repos/huggingface/" + self.directory,
],
debug=args.debug,
)
return proc.stdout.decode("utf-8")
except FileNotFoundError as e:
raise NotImplementedError(
"""\
%s
%s"""
% (str(e).strip("'"), missing_huggingface)
)
if not self.hf_cli_available:
print("huggingface-cli not available, skipping push.")
return
proc = run_cmd(
[
"huggingface-cli",
"upload",
"--repo-type",
"model",
self.directory,
self.filename,
"--cache-dir",
os.path.join(args.store, "repos", "huggingface", ".cache"),
"--local-dir",
os.path.join(args.store, "repos", "huggingface", self.directory),
],
debug=args.debug,
)
return proc.stdout.decode("utf-8")

def symlink_path(self, args):
return f"{args.store}/models/huggingface/{self.directory}/{self.filename}"
return os.path.join(args.store, "models", "huggingface", self.directory, self.filename)

def check_valid_symlink_path(self, relative_target_path, symlink_path):
    """Return True when ``symlink_path`` exists and links to ``relative_target_path``.

    The comparison is against the raw link text returned by ``os.readlink``,
    so the expected target must be spelled exactly as it was when linked.
    """
    if not os.path.exists(symlink_path):
        return False
    current_target = os.readlink(symlink_path)
    return current_target == relative_target_path

def exec(self, args):
try:
exec_cmd(args, args.debug)
except FileNotFoundError as e:
raise NotImplementedError(
"""\
%s
%s
"""
% str(e).strip("'"),
missing_huggingface,
)

def download(self, store):
try:
proc = run_cmd(
[
"huggingface-cli",
"download",
self.directory,
self.filename,
"--cache-dir",
store + "/repos/huggingface/.cache",
"--local-dir",
store + "/repos/huggingface/" + self.directory,
]
)
return proc.stdout.decode("utf-8")
except FileNotFoundError as e:
raise NotImplementedError(
"""\
%s
%s"""
% (str(e).strip("'"), missing_huggingface)
)
print(f"{str(e).strip()}\n{missing_huggingface}")
2 changes: 1 addition & 1 deletion ramalama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def garbage_collection(self, args):
file_has_a_symlink = False
for file in files:
file_path = os.path.join(root, file)
if (repo == "ollama" and file.startswith("sha256:")) or file.endswith(".gguf"):
if file.startswith("sha256:") or file.endswith(".gguf"):
file_path = os.path.join(root, file)
for model_root, model_dirs, model_files in os.walk(model_dir):
for model_file in model_files:
Expand Down
65 changes: 1 addition & 64 deletions ramalama/ollama.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,9 @@
import os
import urllib.request
import json
from ramalama.common import run_cmd, verify_checksum
from ramalama.common import run_cmd, verify_checksum, download_file
from ramalama.model import Model

bar_format = "Pulling {desc}: {percentage:3.0f}% ▕{bar:20}▏ {n_fmt}/{total_fmt} {rate_fmt} {remaining}"


def download_file(url, dest_path, headers=None, show_progress=True):
    """Download ``url`` into ``dest_path``, resuming a partial download when possible.

    If ``dest_path`` already exists, its size is sent as the start offset of an
    HTTP ``Range`` request so that only the missing tail is fetched. The
    existing bytes are kept only when the server answers ``206 Partial
    Content``; any other successful response carries the whole body, so the
    file is rewritten from the start instead of being appended to.

    Args:
        url: Location to fetch; opened with ``urllib.request.urlopen``.
        dest_path: Local file path to write to; its last 16 characters label
            the progress bar.
        headers: Optional mapping of extra request headers.
        show_progress: When true, render a tqdm progress bar (using the
            module-level ``bar_format``) and report an already-complete file
            on stdout. tqdm is imported only in this case.

    Raises:
        NotImplementedError: ``show_progress`` is true and tqdm is not installed.
        urllib.error.HTTPError: Any HTTP error other than 416.
    """
    if show_progress:
        try:
            from tqdm import tqdm
        except ImportError as e:
            # A missing module raises ImportError (ModuleNotFoundError),
            # never FileNotFoundError.
            raise NotImplementedError(
                """\
Ollama models requires the tqdm modules.
This module can be installed via PyPi tools like pip, pip3, pipx or via
distribution package managers like dnf or apt. Example:
pip install tqdm
"""
            ) from e

    # Size of any partially downloaded file; 0 means nothing to resume.
    downloaded_size = os.path.getsize(dest_path) if os.path.exists(dest_path) else 0

    request = urllib.request.Request(url, headers=headers or {})
    request.add_header("Range", f"bytes={downloaded_size}-")  # Ask only for the missing tail

    chunk_size = 8192  # 8 KB chunks

    try:
        with urllib.request.urlopen(request) as response:
            # Only a 206 reply continues from our offset. Anything else (e.g. a
            # plain 200 from a server that ignores Range) is the complete body,
            # so appending it would corrupt the file.
            resumed = downloaded_size > 0 and getattr(response, "status", None) == 206
            if not resumed:
                downloaded_size = 0
            total_size = int(response.headers.get("Content-Length", 0)) + downloaded_size

            progress_bar = None
            if show_progress:
                progress_bar = tqdm(
                    desc=dest_path[-16:],
                    total=total_size,
                    initial=downloaded_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                    bar_format=bar_format,
                    ascii=True,
                )
            try:
                with open(dest_path, "ab" if resumed else "wb") as file:
                    while True:
                        chunk = response.read(chunk_size)
                        if not chunk:
                            break
                        file.write(chunk)
                        if progress_bar is not None:
                            progress_bar.update(len(chunk))
            finally:
                if progress_bar is not None:
                    progress_bar.close()
    except urllib.error.HTTPError as e:
        # 416 (range not satisfiable) means the requested offset is already at
        # the end of the resource, i.e. the file is fully downloaded.
        if e.code != 416:
            raise
        if show_progress:
            print(f"File {url} already fully downloaded.")


def fetch_manifest_data(registry_head, model_tag, accept):
url = f"{registry_head}/manifests/{model_tag}"
Expand Down

0 comments on commit 566d7d0

Please sign in to comment.