Merge pull request #288 from NexaAI/perry-debug
Several general modifications
zhiyuan8 authored Nov 25, 2024
2 parents 99022ea + 96cbc2b commit 38d988a
Showing 9 changed files with 145 additions and 60 deletions.
87 changes: 67 additions & 20 deletions nexa/general.py
@@ -338,7 +338,7 @@ def default_use_processes():
 def download_file_with_progress(
     url: str,
     file_path: Path,
-    chunk_size: int = 40 * 1024 * 1024,
+    chunk_size: int = 5 * 1024 * 1024,
     max_workers: int = 20,
     use_processes: bool = default_use_processes(),
     **kwargs

@@ -396,12 +396,25 @@ def download_file_with_progress(
     progress_bar.close()

     if all(completed_chunks):
+        # Create a new progress bar for combining chunks
+        combine_progress = tqdm(
+            total=file_size,
+            unit='B',
+            unit_scale=True,
+            desc="Verifying download",
+            unit_divisor=1024
+        )
+
+        buffer_size = 1 * 1024 * 1024  # 1MB buffer
+
         with open(file_path, "wb") as final_file:
             for i in range(len(chunks)):
                 chunk_file = temp_dir / f"{file_path.name}.part{i}"
                 with open(chunk_file, "rb") as part_file:
-                    final_file.write(part_file.read())
+                    shutil.copyfileobj(part_file, final_file, buffer_size)
+                combine_progress.update(os.path.getsize(chunk_file))
+
+        combine_progress.close()
     else:
         raise Exception("Some chunks failed to download")
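The replacement of `part_file.read()` with `shutil.copyfileobj` is the substantive fix here: `read()` materializes an entire downloaded chunk in memory before writing it out, while `copyfileobj` streams it through a fixed-size buffer, so peak memory stays near the 1 MB buffer rather than the chunk size. A minimal sketch of the streaming pattern (file names are hypothetical):

```python
import shutil

# Hypothetical part/output names, for illustration only.
with open("model.gguf.part0", "rb") as src, open("model.gguf", "wb") as dst:
    # Copies in 1 MB buffers: memory use is bounded by the buffer size,
    # not by the size of the source file.
    shutil.copyfileobj(src, dst, 1 * 1024 * 1024)
```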
@@ -598,10 +611,11 @@ def is_model_exists(model_name):
     # For AudioLM and Multimodal models, should check the file location instead of model name
     if ":" in model_name:
         model_path_with_slash = model_name.replace(":", "/")
+        model_path_with_backslash = model_name.replace(":", "\\")

-        # Check if model_prefix/model_suffix exists in any location path
+        # Check if model_prefix/model_suffix or model_prefix\model_suffix exists in any location path
         for model_key, model_info in model_list.items():
-            if model_path_with_slash in model_info["location"]:
+            if model_path_with_slash in model_info["location"] or model_path_with_backslash in model_info["location"]:
                 return model_key

     return model_name in model_list

@@ -648,10 +662,11 @@ def get_model_info(model_name):
     # If not found and model_name contains ":", try path-based lookup
     if ":" in model_name:
         model_path_with_slash = model_name.replace(":", "/")
+        model_path_with_backslash = model_name.replace(":", "\\")

-        # Check if model_prefix/model_suffix exists in any location path
+        # Check if model_prefix/model_suffix or model_prefix\model_suffix exists in any location path
         for model_key, model_info in model_list.items():
-            if model_path_with_slash in model_info["location"]:
+            if model_path_with_slash in model_info["location"] or model_path_with_backslash in model_info["location"]:
                 return model_info["location"], model_info["run_type"]

     return None, None
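Both `is_model_exists` and `get_model_info` now check the stored `location` against `prefix/suffix` and `prefix\suffix`, evidently to handle manifests whose paths were recorded with Windows separators. If this pattern spreads further, a single separator-agnostic helper could replace the duplicated condition; a sketch (this helper is hypothetical, not part of the commit):

```python
def location_matches(model_name: str, location: str) -> bool:
    # Normalize Windows backslashes to "/" so one membership test
    # covers both separator styles.
    needle = model_name.replace(":", "/")
    return needle in location.replace("\\", "/")

# location_matches("gemma:2b", r"C:\models\gemma\2b\model.gguf")   -> True
# location_matches("gemma:2b", "/home/user/models/gemma/2b.gguf")  -> True
```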
@@ -701,35 +716,67 @@ def remove_model(model_path):
     with open(NEXA_MODEL_LIST_PATH, "r") as f:
         model_list = json.load(f)

+    # First try direct lookup
     if model_path not in model_list:
-        print(f"Model {model_path} not found.")
-        return
+        # If not found and model_path contains ":", try path-based lookup
+        if ":" in model_path:
+            model_path_with_slash = model_path.replace(":", "/")
+            model_path_with_backslash = model_path.replace(":", "\\")
+
+            # Find matching model key
+            matching_key = None
+            for model_key, model_info in model_list.items():
+                if model_path_with_slash in model_info["location"] or model_path_with_backslash in model_info["location"]:
+                    matching_key = model_key
+                    break
+
+            if matching_key:
+                model_path = matching_key
+            else:
+                print(f"Model {model_path} not found.")
+                return
+        else:
+            print(f"Model {model_path} not found.")
+            return

     model_info = model_list.pop(model_path)
     model_location = model_info['location']
     model_path = Path(model_location)

     # Delete the model files
+    model_deleted = False
     if model_path.is_file():
         model_path.unlink()
         print(f"Deleted model file: {model_path}")
+        model_deleted = True
     elif model_path.is_dir():
         shutil.rmtree(model_path)
         print(f"Deleted model directory: {model_path}")
+        model_deleted = True
     else:
         print(f"Warning: Model location not found: {model_path}")

-    # Delete projectors
-    projector_keys = [k for k in model_list.keys() if 'projector' in k]
-    for key in projector_keys:
-        projector_info = model_list.pop(key)
-        projector_location = Path(projector_info['location'])
-        if projector_location.exists():
-            if projector_location.is_file():
-                projector_location.unlink()
-            else:
-                shutil.rmtree(projector_location)
-            print(f"Deleted projector: {projector_location}")
+    # Delete projectors only if model was successfully deleted
+    if model_deleted:
+        parent_dir = model_path.parent
+        gguf_files = list(parent_dir.glob("*.gguf"))
+
+        # Only proceed if there's exactly one .gguf file in the directory
+        if len(gguf_files) == 1:
+            projector_keys = [
+                k for k in model_list.keys()
+                if 'projector' in k and str(parent_dir) in model_list[k]['location']
+            ]
+
+            for key in projector_keys:
+                projector_info = model_list.pop(key)
+                projector_location = Path(projector_info['location'])
+                if projector_location.exists():
+                    if projector_location.is_file():
+                        projector_location.unlink()
+                    else:
+                        shutil.rmtree(projector_location)
+                    print(f"Deleted projector: {projector_location}")

     # Update the model list file
     with open(NEXA_MODEL_LIST_PATH, "w") as f:
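The reworked projector cleanup in `remove_model` is now guarded twice: projectors are only considered once the model itself was actually deleted, and only when the model's parent directory holds exactly one remaining `.gguf` file, i.e. the orphaned projector of a multimodal pair rather than a file some other model still needs. The guard in isolation, with hypothetical paths:

```python
from pathlib import Path

# Hypothetical layout: the main model file was just deleted, leaving its
# projector behind in the same directory.
parent_dir = Path("/hypothetical/models/nanollava")
gguf_files = list(parent_dir.glob("*.gguf"))

# Exactly one leftover .gguf suggests it is the orphaned projector; zero or
# several means the directory may be shared, so cleanup is skipped.
if len(gguf_files) == 1:
    print(f"Projector candidate safe to delete: {gguf_files[0]}")
```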
79 changes: 59 additions & 20 deletions nexa/gguf/llama/_utils_spinner.py
@@ -1,40 +1,79 @@
 # For similar spinner animation implementation, refer to: nexa/utils.py

 import sys
 import threading
 import time
 import os
+import itertools
+from contextlib import contextmanager

 def get_spinner_style(style="default"):
     spinners = {
-        "default": '⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏'
+        "default": ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
     }
     return spinners.get(style, spinners["default"])

-def spinning_cursor(style="default"):
-    while True:
-        for cursor in get_spinner_style(style):
-            yield cursor
+def _get_output_stream():
+    """Get the appropriate output stream based on platform."""
+    if sys.platform == "win32":
+        return open('CONOUT$', 'wb')
+    else:
+        try:
+            return os.open('/dev/tty', os.O_WRONLY)
+        except (FileNotFoundError, OSError):
+            return os.open('/dev/stdout', os.O_WRONLY)

 def show_spinner(stop_event, style="default", message=""):
-    spinner = spinning_cursor(style)
-
-    fd = os.open('/dev/tty', os.O_WRONLY)
+    spinner = itertools.cycle(get_spinner_style(style))
+    fd = _get_output_stream()
+    is_windows = sys.platform == "win32"

-    while not stop_event.is_set():
-        display = f"\r{message} {next(spinner)}" if message else f"\r{next(spinner)}"
-        os.write(fd, display.encode())
-        time.sleep(0.1)
-
-    os.write(fd, b"\r" + b" " * (len(message) + 2))
-    os.write(fd, b"\r")
-    os.close(fd)
+    try:
+        while not stop_event.is_set():
+            display = f"\r{message} {next(spinner)}" if message else f"\r{next(spinner)} "
+
+            if is_windows:
+                fd.write(display.encode())
+                fd.flush()
+            else:
+                os.write(fd, display.encode())
+            time.sleep(0.1)
+
+        # Clear the spinner
+        clear_msg = b"\r" + b" " * (len(message) + 2) + b"\r"
+        if is_windows:
+            fd.write(clear_msg)
+            fd.flush()
+        else:
+            os.write(fd, clear_msg)
+
+    finally:
+        if is_windows:
+            fd.close()
+        else:
+            os.close(fd)

 def start_spinner(style="default", message=""):
     stop_event = threading.Event()
-    spinner_thread = threading.Thread(target=show_spinner, args=(stop_event, style, message))
-    spinner_thread.daemon = True
+    spinner_thread = threading.Thread(
+        target=show_spinner,
+        args=(stop_event, style, message),
+        daemon=True
+    )
     spinner_thread.start()
     return stop_event, spinner_thread

 def stop_spinner(stop_event, spinner_thread):
-    stop_event.set()
-    spinner_thread.join()
+    if stop_event and not stop_event.is_set():
+        stop_event.set()
+    if spinner_thread and spinner_thread.is_alive():
+        spinner_thread.join()
+
+@contextmanager
+def spinning_cursor(message="", style="default"):
+    """Context manager for spinner animation."""
+    stop_event, thread = start_spinner(style, message)
+    try:
+        yield
+    finally:
+        stop_spinner(stop_event, thread)
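The rewritten module keeps the `start_spinner`/`stop_spinner` pair used by the inference scripts below and layers a `spinning_cursor` context manager on top of it. A usage sketch of both styles (the `time.sleep` calls stand in for blocking model work):

```python
import time
from nexa.gguf.llama._utils_spinner import start_spinner, stop_spinner, spinning_cursor

# Explicit pair, as the inference loops touched by this PR use it:
stop_event, spinner_thread = start_spinner(style="default", message="Generating")
try:
    time.sleep(2)  # stand-in for inference
finally:
    stop_spinner(stop_event, spinner_thread)

# New context manager: cleanup runs automatically, even on error.
with spinning_cursor(message="Generating"):
    time.sleep(2)  # stand-in for inference
```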
11 changes: 5 additions & 6 deletions nexa/gguf/nexa_inference_audio_lm.py
@@ -148,15 +148,15 @@ def run(self):
         Run the audio language model inference loop.
         """
         from nexa.gguf.llama._utils_spinner import start_spinner, stop_spinner

         try:
             while True:
                 audio_path = self._get_valid_audio_path()
                 user_input = nexa_prompt("Enter text (leave empty if no prompt): ")

                 stop_event, spinner_thread = start_spinner(
-                    style="default",
-                    message=""
+                    style="default",
+                    message=""
                 )

                 try:
@@ -172,8 +172,6 @@ def run(self):
             print("\nExiting...")
         except Exception as e:
             logging.error(f"\nError during audio generation: {e}", exc_info=True)
-        # finally:
-        #     self.cleanup()

     def _get_valid_audio_path(self) -> str:
         """
@@ -189,6 +187,7 @@ def _get_valid_audio_path(self) -> str:
         else:
             print(f"'{audio_path}' is not a valid audio path. Please try again.")

+    # @SpinningCursorAnimation()
     def inference(self, audio_path: str, prompt: str = "") -> str:
         """
         Perform a single inference with the audio language model.
8 changes: 4 additions & 4 deletions nexa/gguf/nexa_inference_image.py
@@ -213,8 +213,8 @@ def run_txt2img(self):
         )

         stop_event, spinner_thread = start_spinner(
-            style="default",
-            message=""
+            style="default",
+            message=""
         )

         try:
@@ -292,8 +292,8 @@ def run_img2img(self):
         )

         stop_event, spinner_thread = start_spinner(
-            style="default",
-            message=""
+            style="default",
+            message=""
         )

         images = self.img2img(
4 changes: 2 additions & 2 deletions nexa/gguf/nexa_inference_text.py
@@ -163,8 +163,8 @@ def run(self):
                 generation_start_time = time.time()

                 stop_event, spinner_thread = start_spinner(
-                    style="default",
-                    message=""
+                    style="default",
+                    message=""
                 )

                 if self.chat_format:
4 changes: 2 additions & 2 deletions nexa/gguf/nexa_inference_tts.py
@@ -137,8 +137,8 @@ def run(self):
             user_input = input("Enter text to generate audio: ")

             stop_event, spinner_thread = start_spinner(
-                style="default",
-                message=""
+                style="default",
+                message=""
             )

             audio_data = self.audio_generation(user_input)
4 changes: 2 additions & 2 deletions nexa/gguf/nexa_inference_vlm.py
@@ -242,8 +242,8 @@ def run(self):
                 continue

             stop_event, spinner_thread = start_spinner(
-                style="default",
-                message=""
+                style="default",
+                message=""
             )

             output = self._chat(user_input, image_path)
4 changes: 2 additions & 2 deletions nexa/gguf/nexa_inference_vlm_omni.py
@@ -131,8 +131,8 @@ def run(self):
             user_input = "" if self.omni_vlm_version == "vlm-81-ocr" else nexa_prompt()

             stop_event, spinner_thread = start_spinner(
-                style="default",
-                message=""
+                style="default",
+                message=""
             )

             response = self.inference(user_input, image_path)
4 changes: 2 additions & 2 deletions nexa/gguf/nexa_inference_voice.py
@@ -88,8 +88,8 @@ def run(self):
         audio_path = nexa_prompt("Enter the path to your audio file: ")

         stop_event, spinner_thread = start_spinner(
-            style="default",
-            message=""
+            style="default",
+            message=""
         )

         self._transcribe_audio(audio_path)
