Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prettify and improve download bars during model pulls #521

Merged
merged 1 commit into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ramalama/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def configure_arguments(parser):
)
parser.add_argument("-v", "--version", dest="version", action="store_true", help="show RamaLama version")


def configure_subcommands(parser):
"""Add subcommand parsers to the main argument parser."""
subparsers = parser.add_subparsers(dest="subcommand")
Expand Down
251 changes: 196 additions & 55 deletions ramalama/common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""ramalama common module."""

import fcntl
import hashlib
import os
import random
import shutil
import string
import subprocess
import sys
import time
import urllib.request
import urllib.error

x = False
mnt_dir = "/mnt/models"
Expand Down Expand Up @@ -152,65 +155,203 @@ def genname():
return "ramalama_" + "".join(random.choices(string.ascii_letters + string.digits, k=10))


# The following code is inspired by: https://github.com/ericcurtin/lm-pull/blob/main/lm-pull.py

class File:
    """Thin wrapper around a file object plus an advisory ``flock`` lock.

    The wrapper remembers the open file and, once ``lock()`` succeeds, the
    descriptor that holds the lock so both can be released together.
    Release happens deterministically through ``close()`` (or by using the
    instance as a context manager); ``__del__`` remains as a last-resort
    fallback for callers that simply drop the object.
    """

    def __init__(self):
        self.file = None  # file object returned by open(), or None
        self.fd = -1  # descriptor currently holding the flock, or -1

    def open(self, filename, mode):
        """Open *filename* with *mode*, remember the file object and return it.

        Any exception raised by the builtin ``open`` propagates unchanged.
        """
        self.file = open(filename, mode)
        return self.file

    def lock(self):
        """Try to take an exclusive, non-blocking ``flock`` on the open file.

        Returns:
            0 when the lock was acquired (or when no file is open, in which
            case nothing is locked), 1 when the lock is already held
            elsewhere.
        """
        if self.file:
            self.fd = self.file.fileno()
            try:
                fcntl.flock(self.fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
            except BlockingIOError:
                # Somebody else owns the lock; forget the descriptor so that
                # close() does not try to unlock something we never held.
                self.fd = -1
                return 1

        return 0

    def close(self):
        """Release the lock (if held) and close the file (if open).

        Safe to call more than once; later calls are no-ops.
        """
        try:
            if self.fd >= 0:
                fcntl.flock(self.fd, fcntl.LOCK_UN)
        finally:
            # Always reset state and close the file, even if unlocking failed.
            self.fd = -1
            if self.file is not None:
                self.file.close()
                self.file = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def __del__(self):
        self.close()


class HttpClient:
    """Small resumable HTTP downloader that can draw a terminal progress bar.

    ``init()`` is the entry point: it issues a ranged GET request, either
    collects the body as text or streams it into ``<output_file>.partial``,
    and renames the partial file to its final name afterwards.
    """

    def __init__(self):
        pass

    def init(self, url, headers, output_file, progress, response_str=None):
        """Fetch *url* and store or collect the response body.

        Args:
            url: URL to request.
            headers: Mapping of request headers; it is not modified.
            output_file: Final destination path. Data is written to
                ``output_file + ".partial"`` and renamed once complete.
            progress: When truthy, draw a progress bar while downloading.
            response_str: Optional list; when given, the decoded (UTF-8)
                body is appended to it instead of being written to disk.

        Returns:
            0 on success, 1 when the request, open or lock step failed.
        """
        output_file_partial = None
        if output_file:
            output_file_partial = output_file + ".partial"

        self.file_size = self.set_resume_point(output_file_partial)
        self.printed = False
        if self.urlopen(url, headers):
            return 1

        self.total_to_download = int(self.response.getheader('content-length', 0))
        if response_str is not None:
            response_str.append(self.response.read().decode('utf-8'))
        else:
            out = File()
            if not out.open(output_file_partial, "ab"):
                print("Failed to open file")

                return 1

            if out.lock():
                print("Failed to exclusively lock file")

                return 1

            if self.file_size == 0:
                # Nothing to resume, or the server ignored the Range header
                # and is sending the whole body: drop stale partial data so
                # the full body is not appended after it.
                out.file.truncate(0)

            self.now_downloaded = 0
            self.start_time = time.time()
            self.perform_download(out.file, progress)
            # Make sure everything is on disk before the file gets its
            # final name.
            out.file.flush()

        if output_file:
            os.rename(output_file_partial, output_file)

        if self.printed:
            print("\n")

        return 0

    def urlopen(self, url, headers):
        """Open *url* asking for the bytes after ``self.file_size``.

        The response is stored in ``self.response``. If the server answers
        200 (full body) rather than 206 (partial body), ``self.file_size``
        is reset to 0 because resuming is not possible.

        Returns:
            0 on success, 1 on an HTTP error or unexpected status.
        """
        # Work on a copy so the caller's header mapping is left untouched.
        request_headers = dict(headers or {})
        request_headers["Range"] = f"bytes={self.file_size}-"
        request = urllib.request.Request(url, headers=request_headers)
        try:
            self.response = urllib.request.urlopen(request)
        except urllib.error.HTTPError as e:
            print(f"Request failed: {e.code}", file=sys.stderr)

            return 1

        if self.response.status not in (200, 206):
            print(f"Request failed: {self.response.status}", file=sys.stderr)

            return 1

        if self.response.status == 200:
            # Range was ignored: the body starts at byte 0.
            self.file_size = 0

        return 0

    def perform_download(self, file, progress):
        """Stream ``self.response`` into *file*, optionally showing progress."""
        # content-length only covers what is still to come; add what is
        # already on disk so percentages refer to the whole file.
        self.total_to_download += self.file_size
        self.now_downloaded = 0
        self.start_time = time.time()
        while True:
            data = self.response.read(1024)
            if not data:
                break

            size = file.write(data)
            if progress:
                self.update_progress(size)

    def human_readable_time(self, seconds):
        """Format *seconds* as ``Hh MMm SSs`` / ``Mm SSs`` / ``Ss``, right-aligned to 10 columns."""
        hrs = int(seconds) // 3600
        mins = (int(seconds) % 3600) // 60
        secs = int(seconds) % 60
        width = 10
        if hrs > 0:
            return f"{hrs}h {mins:02}m {secs:02}s".rjust(width)
        elif mins > 0:
            return f"{mins}m {secs:02}s".rjust(width)
        else:
            return f"{secs}s".rjust(width)

    def human_readable_size(self, size):
        """Format a byte count with a 1024-based unit, right-aligned to 10 columns."""
        width = 10
        for unit in ["B", "KB", "MB", "GB", "TB"]:
            if size < 1024:
                return f"{size:.2f} {unit}".rjust(width)

            size /= 1024

        return f"{size:.2f} PB".rjust(width)

    def get_terminal_width(self):
        """Return the current terminal width in columns."""
        return shutil.get_terminal_size().columns

    def generate_progress_prefix(self, percentage):
        """Return the ``NN% |`` text placed to the left of the bar."""
        return f"{percentage}% |".rjust(6)

    def generate_progress_suffix(self, now_downloaded_plus_file_size, speed, estimated_time):
        """Return the ``done/total  speed/s  eta`` text placed to the right of the bar."""
        done = self.human_readable_size(now_downloaded_plus_file_size)
        total = self.human_readable_size(self.total_to_download)
        rate = self.human_readable_size(speed)
        eta = self.human_readable_time(estimated_time)
        return f"{done}/{total}{rate}/s{eta}"

    def calculate_progress_bar_width(self, progress_prefix, progress_suffix):
        """Return the columns left for the bar itself (never less than 1)."""
        progress_bar_width = self.get_terminal_width() - len(progress_prefix) - len(progress_suffix) - 3
        return max(progress_bar_width, 1)

    def generate_progress_bar(self, progress_bar_width, percentage):
        """Return a bar of *progress_bar_width* cells filled to *percentage*."""
        pos = (percentage * progress_bar_width) // 100
        # Clamp so out-of-range percentages still give a bar of exact width.
        filled = min(max(pos, 0), max(progress_bar_width, 0))
        return "█" * filled + " " * (progress_bar_width - filled)

    def set_resume_point(self, output_file):
        """Return the size of *output_file* if it exists, otherwise 0."""
        if output_file and os.path.exists(output_file):
            return os.path.getsize(output_file)

        return 0

    def print_progress(self, progress_prefix, progress_bar, progress_suffix):
        """Redraw the progress line in place (carriage return, no newline)."""
        print(f"\r{progress_prefix}{progress_bar}| {progress_suffix}", end="")

    def update_progress(self, chunk_size):
        """Account for *chunk_size* newly written bytes and redraw the bar."""
        self.now_downloaded += chunk_size
        now_downloaded_plus_file_size = self.now_downloaded + self.file_size
        percentage = (now_downloaded_plus_file_size * 100) // self.total_to_download if self.total_to_download else 100
        progress_prefix = self.generate_progress_prefix(percentage)
        speed = self.calculate_speed(self.now_downloaded, self.start_time)
        # total_to_download includes the bytes that were already on disk, so
        # the remaining amount must subtract those as well. Clamp at zero for
        # responses without a content-length.
        remaining = max(self.total_to_download - now_downloaded_plus_file_size, 0)
        tim = remaining // speed if speed > 0 else 0
        progress_suffix = self.generate_progress_suffix(now_downloaded_plus_file_size, speed, tim)
        progress_bar_width = self.calculate_progress_bar_width(progress_prefix, progress_suffix)
        progress_bar = self.generate_progress_bar(progress_bar_width, percentage)
        self.print_progress(progress_prefix, progress_bar, progress_suffix)
        self.printed = True

    def calculate_speed(self, now_downloaded, start_time):
        """Return the average bytes per second since *start_time*.

        Returns 0.0 when no time has elapsed yet (or the clock stepped
        backwards) rather than dividing by zero.
        """
        elapsed_seconds = time.time() - start_time
        if elapsed_seconds <= 0:
            return 0.0

        return now_downloaded / elapsed_seconds


def download_file(url, dest_path, headers=None, show_progress=True):
"""
Downloads a file from a given URL to a specified destination path.

Args:
url (str): The URL to download from.
dest_path (str): The path to save the downloaded file.
headers (dict): Optional headers to include in the request.
show_progress (bool): Whether to show a progress bar during download.

Returns:
None
"""
# NOTE(review): this span comes from a rendered diff, so lines of the
# tqdm-based implementation and of the HttpClient-based one appear
# interleaved and unindented; verify against the merged file before
# relying on the exact statement order shown here.
http_client = HttpClient()

# Normalise a missing headers argument to an empty dict.
headers = headers or {}

try:
from tqdm import tqdm
# NOTE(review): a failed import raises ImportError (ModuleNotFoundError),
# not FileNotFoundError, so this handler would not catch a missing tqdm
# module; confirm intent.
except FileNotFoundError:
raise NotImplementedError(
"""\
Ollama models requires the tqdm modules.
This module can be installed via PyPi tools like pip, pip3, pipx or via
distribution package managers like dnf or apt. Example:
pip install tqdm
"""
)

# Check if partially downloaded file exists
if os.path.exists(dest_path):
downloaded_size = os.path.getsize(dest_path)
else:
downloaded_size = 0

request = urllib.request.Request(url, headers=headers or {})
request.headers["Range"] = f"bytes={downloaded_size}-" # Set range header

# Last path component, used as the progress bar description.
filename = dest_path.split('/')[-1]

bar_format = "Pulling {desc}: {percentage:3.0f}% ▕{bar:20}▏ {n_fmt}/{total_fmt} {rate_fmt} {remaining}"
try:
with urllib.request.urlopen(request) as response:
# Content-Length is added to the bytes already on disk to get the total.
total_size = int(response.headers.get("Content-Length", 0)) + downloaded_size
chunk_size = 8192 # 8 KB chunks

# Append mode so data already present in dest_path is kept.
with open(dest_path, "ab") as file:
if show_progress:
with tqdm(
desc=filename,
total=total_size,
initial=downloaded_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
bar_format=bar_format,
ascii=True,
) as progress_bar:
while True:
chunk = response.read(chunk_size)
if not chunk:
break
file.write(chunk)
progress_bar.update(len(chunk))
else:
# Download file without showing progress
while True:
chunk = response.read(chunk_size)
if not chunk:
break
file.write(chunk)
# Hand the transfer to HttpClient: dest_path is passed as output_file and
# show_progress as progress. The return value of init() is not inspected.
http_client.init(url=url, headers=headers, output_file=dest_path, progress=show_progress)
# NOTE(review): HttpClient.urlopen catches urllib.error.HTTPError itself and
# returns 1, so this handler may not be reachable through http_client.init;
# confirm.
except urllib.error.HTTPError as e:
# NOTE(review): the 416 check appears twice (two versions of the same diff
# line); presumably only one should remain.
if e.code == 416:
if e.code == 416: # Range not satisfiable
if show_progress:
# If we get a 416 error, it means the file is fully downloaded
print(f"File {url} already fully downloaded.")
else:
raise e
Expand Down
Loading