From 2e06689fc6de46740821a3c08db1ac9ba06a1254 Mon Sep 17 00:00:00 2001 From: swarajpande5 Date: Sun, 22 Dec 2024 14:06:31 +0530 Subject: [PATCH] Prettify and improve download bars during model pulls Signed-off-by: swarajpande5 --- ramalama/cli.py | 1 + ramalama/common.py | 251 +++++++++++++++++++++++++++++++++++---------- 2 files changed, 197 insertions(+), 55 deletions(-) diff --git a/ramalama/cli.py b/ramalama/cli.py index 5694bf56..415c8ef9 100644 --- a/ramalama/cli.py +++ b/ramalama/cli.py @@ -215,6 +215,7 @@ def configure_arguments(parser): ) parser.add_argument("-v", "--version", dest="version", action="store_true", help="show RamaLama version") + def configure_subcommands(parser): """Add subcommand parsers to the main argument parser.""" subparsers = parser.add_subparsers(dest="subcommand") diff --git a/ramalama/common.py b/ramalama/common.py index a2563e8a..40973569 100644 --- a/ramalama/common.py +++ b/ramalama/common.py @@ -1,5 +1,6 @@ """ramalama common module.""" +import fcntl import hashlib import os import random @@ -7,7 +8,9 @@ import string import subprocess import sys +import time import urllib.request +import urllib.error x = False mnt_dir = "/mnt/models" @@ -152,65 +155,203 @@ def genname(): return "ramalama_" + "".join(random.choices(string.ascii_letters + string.digits, k=10)) +# The following code is inspired from: https://github.com/ericcurtin/lm-pull/blob/main/lm-pull.py + +class File: + def __init__(self): + self.file = None + self.fd = -1 + + def open(self, filename, mode): + self.file = open(filename, mode) + return self.file + + def lock(self): + if self.file: + self.fd = self.file.fileno() + try: + fcntl.flock(self.fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError: + self.fd = -1 + return 1 + + return 0 + + def __del__(self): + if self.fd >= 0: + fcntl.flock(self.fd, fcntl.LOCK_UN) + + if self.file: + self.file.close() + + +class HttpClient: + def __init__(self): + pass + + def init(self, url, headers, output_file, progress, response_str=None): + output_file_partial = None + if output_file: + output_file_partial = output_file + ".partial" + + self.file_size = self.set_resume_point(output_file_partial) + self.printed = False + if self.urlopen(url, headers): + return 1 + + self.total_to_download = int(self.response.getheader('content-length', 0)) + if response_str is not None: + response_str.append(self.response.read().decode('utf-8')) + else: + out = File() + if not out.open(output_file_partial, "ab"): + print("Failed to open file") + + return 1 + + if out.lock(): + print("Failed to exclusively lock file") + + return 1 + + self.now_downloaded = 0 + self.start_time = time.time() + self.perform_download(out.file, progress) + + if output_file: + os.rename(output_file_partial, output_file) + + if self.printed: + print("\n") + + return 0 + + def urlopen(self, url, headers): + headers["Range"] = f"bytes={self.file_size}-" + request = urllib.request.Request(url, headers=headers) + try: + self.response = urllib.request.urlopen(request) + except urllib.error.HTTPError as e: + print(f"Request failed: {e.code}", file=sys.stderr) + + return 1 + + if self.response.status not in (200, 206): + print(f"Request failed: {self.response.status}", file=sys.stderr) + + return 1 + + return 0 + + def perform_download(self, file, progress): + self.total_to_download += self.file_size + self.now_downloaded = 0 + self.start_time = time.time() + while True: + data = self.response.read(1024) + if not data: + break + + size = file.write(data) + if progress: + self.update_progress(size) + + def human_readable_time(self, seconds): + hrs = int(seconds) // 3600 + mins = (int(seconds) % 3600) // 60 + secs = int(seconds) % 60 + width = 10 + if hrs > 0: + return f"{hrs}h {mins:02}m {secs:02}s".rjust(width) + elif mins > 0: + return f"{mins}m {secs:02}s".rjust(width) + else: + return f"{secs}s".rjust(width) + + def human_readable_size(self, size): + width = 10 + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size < 1024: + return f"{size:.2f} {unit}".rjust(width) + + size /= 1024 + + return f"{size:.2f} PB".rjust(width) + + def get_terminal_width(self): + return shutil.get_terminal_size().columns + + def generate_progress_prefix(self, percentage): + return f"{percentage}% |".rjust(6) + + def generate_progress_suffix(self, now_downloaded_plus_file_size, speed, estimated_time): + return f"{self.human_readable_size(now_downloaded_plus_file_size)}/{self.human_readable_size(self.total_to_download)}{self.human_readable_size(speed)}/s{self.human_readable_time(estimated_time)}" # noqa: E501 + + def calculate_progress_bar_width(self, progress_prefix, progress_suffix): + progress_bar_width = self.get_terminal_width() - len(progress_prefix) - len(progress_suffix) - 3 + if progress_bar_width < 1: + progress_bar_width = 1 + + return progress_bar_width + + def generate_progress_bar(self, progress_bar_width, percentage): + pos = (percentage * progress_bar_width) // 100 + progress_bar = "" + for i in range(progress_bar_width): + progress_bar += "█" if i < pos else " " + + return progress_bar + + def set_resume_point(self, output_file): + if output_file and os.path.exists(output_file): + return os.path.getsize(output_file) + + return 0 + + def print_progress(self, progress_prefix, progress_bar, progress_suffix): + print(f"\r{progress_prefix}{progress_bar}| {progress_suffix}", end="") + + def update_progress(self, chunk_size): + self.now_downloaded += chunk_size + now_downloaded_plus_file_size = self.now_downloaded + self.file_size + percentage = (now_downloaded_plus_file_size * 100) // self.total_to_download if self.total_to_download else 100 + progress_prefix = self.generate_progress_prefix(percentage) + speed = self.calculate_speed(self.now_downloaded, self.start_time) + tim = (self.total_to_download - self.now_downloaded) // speed + progress_suffix = self.generate_progress_suffix(now_downloaded_plus_file_size, speed, tim) + progress_bar_width = self.calculate_progress_bar_width(progress_prefix, progress_suffix) + progress_bar = self.generate_progress_bar(progress_bar_width, percentage) + self.print_progress(progress_prefix, progress_bar, progress_suffix) + self.printed = True + + def calculate_speed(self, now_downloaded, start_time): + now = time.time() + elapsed_seconds = now - start_time + return now_downloaded / elapsed_seconds + + def download_file(url, dest_path, headers=None, show_progress=True): + """ + Downloads a file from a given URL to a specified destination path. + + Args: + url (str): The URL to download from. + dest_path (str): The path to save the downloaded file. + headers (dict): Optional headers to include in the request. + show_progress (bool): Whether to show a progress bar during download. + + Returns: + None + """ + http_client = HttpClient() + + headers = headers or {} + try: - from tqdm import tqdm - except FileNotFoundError: - raise NotImplementedError( - """\ -Ollama models requires the tqdm modules. -This module can be installed via PyPi tools like pip, pip3, pipx or via -distribution package managers like dnf or apt. Example: -pip install tqdm -""" - ) - - # Check if partially downloaded file exists - if os.path.exists(dest_path): - downloaded_size = os.path.getsize(dest_path) - else: - downloaded_size = 0 - - request = urllib.request.Request(url, headers=headers or {}) - request.headers["Range"] = f"bytes={downloaded_size}-" # Set range header - - filename = dest_path.split('/')[-1] - - bar_format = "Pulling {desc}: {percentage:3.0f}% ▕{bar:20}▏ {n_fmt}/{total_fmt} {rate_fmt} {remaining}" - try: - with urllib.request.urlopen(request) as response: - total_size = int(response.headers.get("Content-Length", 0)) + downloaded_size - chunk_size = 8192 # 8 KB chunks - - with open(dest_path, "ab") as file: - if show_progress: - with tqdm( - desc=filename, - total=total_size, - initial=downloaded_size, - unit="B", - unit_scale=True, - unit_divisor=1024, - bar_format=bar_format, - ascii=True, - ) as progress_bar: - while True: - chunk = response.read(chunk_size) - if not chunk: - break - file.write(chunk) - progress_bar.update(len(chunk)) - else: - # Download file without showing progress - while True: - chunk = response.read(chunk_size) - if not chunk: - break - file.write(chunk) + http_client.init(url=url, headers=headers, output_file=dest_path, progress=show_progress) except urllib.error.HTTPError as e: - if e.code == 416: + if e.code == 416: # Range not satisfiable if show_progress: - # If we get a 416 error, it means the file is fully downloaded print(f"File {url} already fully downloaded.") else: raise e