
Commit

Merge pull request #325 from NexaAI/localui-group
Localui group
zhiyuan8 authored Dec 20, 2024
2 parents bdee294 + f38922c commit 20ab712
Showing 14 changed files with 1,176 additions and 178 deletions.
24 changes: 21 additions & 3 deletions nexa/cli/entry.py
@@ -152,7 +152,7 @@ def run_ggml_server(args):
from nexa.gguf.server.nexa_service import run_nexa_ai_service as NexaServer

kwargs = {k: v for k, v in vars(args).items() if v is not None}
model_path = kwargs.pop("model_path")
model_path = kwargs.pop("model_path", None)
is_local_path = kwargs.pop("local_path", False)
model_type = kwargs.pop("model_type", None)
hf = kwargs.pop('huggingface', False)
@@ -272,6 +272,15 @@ def run_eval_tasks(args):
print("Please run: pip install 'nexaai[eval]'")
return

def run_siglip_server(args):
from nexa.siglip.nexa_siglip_server import run_nexa_ai_siglip_service
run_nexa_ai_siglip_service(
image_dir=args.image_dir,
host=args.host,
port=args.port,
reload=args.reload
)

def run_embedding_generation(args):
kwargs = {k: v for k, v in vars(args).items() if v is not None}
model_path = kwargs.pop("model_path")
@@ -556,8 +565,8 @@ def main():
quantization_parser.add_argument("--keep_split", action="store_true", help="Quantize to the same number of shards")

# GGML server parser
server_parser = subparsers.add_parser("server", help="Run the Nexa AI Text Generation Service")
server_parser.add_argument("model_path", type=str, nargs='?', help="Path or identifier for the model in Nexa Model Hub")
server_parser = subparsers.add_parser("server", help="Run the Nexa AI local service")
server_parser.add_argument("--model_path", type=str, help="Path or identifier for the model in Nexa Model Hub")
server_parser.add_argument("-lp", "--local_path", action="store_true", help="Indicate that the model path provided is the local path")
server_parser.add_argument("-mt", "--model_type", type=str, choices=[e.name for e in ModelType], help="Indicate the model running type, must be used with -lp, -hf or -ms")
server_parser.add_argument("-hf", "--huggingface", action="store_true", help="Load model from Hugging Face Hub")
@@ -599,6 +608,13 @@ def main():
perf_eval_group.add_argument("--device", type=str, help="Device to run performance evaluation on, choose from 'cpu', 'cuda', 'mps'", default="cpu")
perf_eval_group.add_argument("--new_tokens", type=int, help="Number of new tokens to evaluate", default=100)

# Siglip Server
siglip_parser = subparsers.add_parser("siglip", help="Run the Nexa AI SigLIP Service")
siglip_parser.add_argument("--image_dir", type=str, help="Directory of images to load")
siglip_parser.add_argument("--host", type=str, default="localhost", help="Host to bind the server to")
siglip_parser.add_argument("--port", type=int, default=8100, help="Port to bind the server to")
siglip_parser.add_argument("--reload", action="store_true", help="Enable automatic reloading on code changes")

args = parser.parse_args()

if args.command == "run":
@@ -627,6 +643,8 @@ def main():
run_onnx_inference(args)
elif args.command == "eval":
run_eval_tasks(args)
elif args.command == "siglip":
run_siglip_server(args)
elif args.command == "embed":
run_embedding_generation(args)
elif args.command == "pull":
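
Taken together, the entry.py changes add a "siglip" subcommand that forwards its arguments to run_nexa_ai_siglip_service, and switch the "server" subcommand's model path from a positional argument to the --model_path flag. Below is a minimal sketch of calling the new service function directly with the parameters defined above; the image directory is illustrative, and the CLI equivalence assumes the installed console script is named "nexa", which this diff does not show.

    from nexa.siglip.nexa_siglip_server import run_nexa_ai_siglip_service

    # Roughly what "nexa siglip --image_dir ./images" would do with the parser defaults above.
    run_nexa_ai_siglip_service(
        image_dir="./images",   # illustrative path, not part of this diff
        host="localhost",       # parser default
        port=8100,              # parser default
        reload=False,
    )
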
2 changes: 2 additions & 0 deletions nexa/constants.py
@@ -66,6 +66,7 @@ class ModelType(Enum):
"phi2": "Phi-2:q4_0",
"phi3": "Phi-3-mini-128k-instruct:q4_0",
"phi3.5": "Phi-3.5-mini-instruct:q4_0",
"phi4": "Phi:q4_0",
"llama2-uncensored": "Llama2-7b-chat-uncensored:q4_0",
"llama3-uncensored": "Llama3-8B-Lexi-Uncensored:q4_K_M",
"openelm": "OpenELM-3B:q4_K_M",
@@ -413,6 +414,7 @@ class ModelType(Enum):
"Phi-3-mini-128k-instruct": ModelType.NLP,
"Phi-3-mini-4k-instruct": ModelType.NLP,
"Phi-3.5-mini-instruct": ModelType.NLP,
"Phi-4": ModelType.NLP,
"CodeQwen1.5-7B-Instruct": ModelType.NLP,
"Qwen2-0.5B-Instruct": ModelType.NLP,
"Qwen2-1.5B-Instruct": ModelType.NLP,
1 change: 0 additions & 1 deletion nexa/eval/nexa_eval.py
@@ -12,7 +12,6 @@
from nexa.eval import evaluator
from nexa.eval.nexa_task.task_manager import TaskManager
from nexa.eval.utils import make_table, handle_non_serializable
from nexa.gguf.server.nexa_service import run_nexa_ai_service as NexaServer
from nexa.constants import NEXA_MODEL_EVAL_RESULTS_PATH, NEXA_RUN_MODEL_MAP
from nexa.eval.nexa_perf import (
Benchmark,
24 changes: 24 additions & 0 deletions nexa/gguf/llama/audio_lm_cpp.py
@@ -93,6 +93,21 @@ def process_full(ctx: omni_context_p, params: omni_context_params_p, is_qwen: bo
_lib = _lib_qwen if is_qwen else _lib_omni
return _lib.omni_process_full(ctx, params)


def process_streaming(ctx: omni_context_p, params: omni_context_params_p, is_qwen: bool = True):
_lib = _lib_qwen if is_qwen else _lib_omni
return _lib.omni_process_streaming(ctx, params)


def sample(omni_streaming: ctypes.c_void_p, is_qwen: bool = True):
_lib = _lib_qwen if is_qwen else _lib_omni
return _lib.sample(omni_streaming)


def get_str(omni_streaming: ctypes.c_void_p, is_qwen: bool = True):
_lib = _lib_qwen if is_qwen else _lib_omni
return _lib.get_str(omni_streaming)

# OMNI_AUDIO_API void omni_free(struct omni_context *ctx_omni);
def free(ctx: omni_context_p, is_qwen: bool = True):
_lib = _lib_qwen if is_qwen else _lib_omni
@@ -111,6 +126,15 @@ def free(ctx: omni_context_p, is_qwen: bool = True):
lib.omni_process_full.argtypes = [omni_context_p, omni_context_params_p]
lib.omni_process_full.restype = ctypes.c_char_p

lib.omni_process_streaming.argtypes = [omni_context_p, omni_context_params_p]
lib.omni_process_streaming.restype = ctypes.c_void_p

lib.sample.argtypes = [ctypes.c_void_p]
lib.sample.restype = ctypes.c_int32

lib.get_str.argtypes = [ctypes.c_void_p]
lib.get_str.restype = ctypes.c_char_p

# Configure free
lib.omni_free.argtypes = [omni_context_p]
lib.omni_free.restype = None
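
These bindings give the audio path a streaming counterpart to omni_process_full: process_streaming returns an opaque stream handle, and sample/get_str are then polled until sample reports a negative value. Below is a rough sketch of the call pattern, mirroring the consumer added later in this commit; the context and params are assumed to come from audio_lm_cpp.init_context and the context-params struct used in nexa_inference_audio_lm.py, and are not created here.

    import ctypes

    from nexa.gguf.llama import audio_lm_cpp

    def stream_omni_chunks(ctx, params, is_qwen=True):
        # Start a streaming generation on an already-initialized omni context.
        oss = audio_lm_cpp.process_streaming(ctx, ctypes.byref(params), is_qwen=is_qwen)
        res = 0
        while res >= 0:
            # A negative sample() result marks the end of the stream.
            res = audio_lm_cpp.sample(oss, is_qwen=is_qwen)
            yield audio_lm_cpp.get_str(oss, is_qwen=is_qwen).decode("utf-8")
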
26 changes: 25 additions & 1 deletion nexa/gguf/llama/omni_vlm_cpp.py
@@ -71,9 +71,33 @@ def omnivlm_inference(prompt: omni_char_p, image_path: omni_char_p):
_lib.omnivlm_inference.restype = omni_char_p


def omnivlm_inference_streaming(prompt: omni_char_p, image_path: omni_char_p):
return _lib.omnivlm_inference_streaming(prompt, image_path)


_lib.omnivlm_inference_streaming.argtypes = [omni_char_p, omni_char_p]
_lib.omnivlm_inference_streaming.restype = ctypes.c_void_p


def sample(omni_vlm_streaming: ctypes.c_void_p):
return _lib.sample(omni_vlm_streaming)


_lib.sample.argtypes = [ctypes.c_void_p]
_lib.sample.restype = ctypes.c_int32


def get_str(omni_vlm_streaming: ctypes.c_void_p):
return _lib.get_str(omni_vlm_streaming)


_lib.get_str.argtypes = [ctypes.c_void_p]
_lib.get_str.restype = ctypes.c_char_p


def omnivlm_free():
return _lib.omnivlm_free()


_lib.omnivlm_free.argtypes = []
_lib.omnivlm_free.restype = None
_lib.omnivlm_free.restype = None
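
The omni-VLM bindings follow the same streaming pattern, but take the prompt and image path directly as C strings. Below is a small sketch of how they compose; it assumes the model has already been loaded through the library's init call, which sits outside this diff, and it omits the sentinel filtering done by the higher-level class.

    import ctypes

    from nexa.gguf.llama import omni_vlm_cpp

    def stream_vlm_chunks(prompt: str, image_path: str):
        # Kick off streaming inference for a prompt/image pair on a loaded model.
        oss = omni_vlm_cpp.omnivlm_inference_streaming(
            ctypes.c_char_p(prompt.encode("utf-8")),
            ctypes.c_char_p(image_path.encode("utf-8")),
        )
        res = 0
        while res >= 0:
            # A negative sample() result marks the end of the stream.
            res = omni_vlm_cpp.sample(oss)
            yield omni_vlm_cpp.get_str(oss).decode("utf-8")
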
51 changes: 48 additions & 3 deletions nexa/gguf/nexa_inference_audio_lm.py
@@ -160,12 +160,21 @@ def run(self):
)

try:
with suppress_stdout_stderr():
response = self.inference(audio_path, user_input)
# with suppress_stdout_stderr():
# response = self.inference(audio_path, user_input)
first_chunk = True
for chunk in self.inference_streaming(audio_path, user_input):
if first_chunk:
stop_spinner(stop_event, spinner_thread)
first_chunk = False
if chunk == '\n':
chunk = ''
# print("FUCK")
print(chunk, end='', flush=True)
print() # '\n'
finally:
stop_spinner(stop_event, spinner_thread)

print(f"{response}")
self.cleanup()

except KeyboardInterrupt:
@@ -216,6 +225,42 @@ def inference(self, audio_path: str, prompt: str = "") -> str:
except Exception as e:
raise RuntimeError(f"Error during inference: {str(e)}")

def inference_streaming(self, audio_path: str, prompt: str = ""):
"""
Perform a streaming inference with the audio language model, yielding decoded text chunks as they are generated.
"""
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")

try:
# Ensure audio is at 16kHz before processing
audio_path = self._ensure_16khz(audio_path)

self.ctx_params.file = ctypes.c_char_p(audio_path.encode("utf-8"))
self.ctx_params.prompt = ctypes.c_char_p(prompt.encode("utf-8"))

with suppress_stdout_stderr():
self.context = audio_lm_cpp.init_context(
ctypes.byref(self.ctx_params), is_qwen=self.is_qwen
)
if not self.context:
raise RuntimeError("Failed to load audio language model")
logging.debug("Model loaded successfully")

oss = audio_lm_cpp.process_streaming(
self.context, ctypes.byref(self.ctx_params), is_qwen=self.is_qwen
)
res = 0
while res >= 0:
res = audio_lm_cpp.sample(oss)
res_str = audio_lm_cpp.get_str(oss).decode('utf-8')

if '<|im_start|>' in res_str or '</s>' in res_str:
continue
yield res_str
except Exception as e:
raise RuntimeError(f"Error during inference: {str(e)}")

def cleanup(self):
"""
Explicitly cleanup resources
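
At the class level, inference_streaming is a generator, so callers can print chunks as they arrive instead of waiting for a full response. Below is a hedged usage sketch; the class name NexaAudioLMInference and its model_path constructor argument are assumed from the rest of this file, which this diff does not show, and the model identifier and audio file are purely illustrative.

    from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference  # assumed import

    audio_lm = NexaAudioLMInference(model_path="qwen2audio")  # illustrative model identifier
    for chunk in audio_lm.inference_streaming("sample.wav", "Describe this audio clip"):
        print(chunk, end="", flush=True)
    print()
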
33 changes: 26 additions & 7 deletions nexa/gguf/nexa_inference_vlm_omni.py
@@ -1,3 +1,4 @@
import time
import ctypes
import logging
import os
@@ -126,20 +127,24 @@ def run(self):
try:
image_path = nexa_prompt("Image Path (required): ")
if not os.path.exists(image_path):
print(f"Image path: {image_path} not found, running omni VLM without image input.")
print(f"Image path: {image_path} not found, exiting...")
exit(1)
# Skip user input for OCR version
user_input = "" if self.omni_vlm_version == "vlm-81-ocr" else nexa_prompt()

stop_event, spinner_thread = start_spinner(
style="default",
message=""
)
first_chunk = True
for chunk in self.inference_streaming(user_input, image_path):
if first_chunk:
stop_spinner(stop_event, spinner_thread)
first_chunk = False
if chunk == '\n':
chunk = ''
print(chunk, end='', flush=True)

response = self.inference(user_input, image_path)

stop_spinner(stop_event, spinner_thread)

print(f"\nResponse: {response}")
except KeyboardInterrupt:
print("\nExiting...")
break
@@ -159,6 +164,20 @@ def inference(self, prompt: str, image_path: str):

return decoded_response

def inference_streaming(self, prompt: str, image_path: str):
with suppress_stdout_stderr():
prompt = ctypes.c_char_p(prompt.encode("utf-8"))
image_path = ctypes.c_char_p(image_path.encode("utf-8"))
oss = omni_vlm_cpp.omnivlm_inference_streaming(prompt, image_path)

res = 0
while res >= 0:
res = omni_vlm_cpp.sample(oss)
res_str = omni_vlm_cpp.get_str(oss).decode('utf-8')
if '<|im_start|>' in res_str or '</s>' in res_str:
continue
yield res_str

def __del__(self):
omni_vlm_cpp.omnivlm_free()

@@ -218,4 +237,4 @@ def run_streamlit(self, model_path: str, is_local_path = False, hf = False, proj
if args.streamlit:
inference.run_streamlit(model_path)
else:
inference.run()
inference.run()