Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update load_model function to support loading both model and whisper_model #327

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nexa/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class ModelType(Enum):
"phi2": "Phi-2:q4_0",
"phi3": "Phi-3-mini-128k-instruct:q4_0",
"phi3.5": "Phi-3.5-mini-instruct:q4_0",
"phi4": "Phi:q4_0",
"llama2-uncensored": "Llama2-7b-chat-uncensored:q4_0",
"llama3-uncensored": "Llama3-8B-Lexi-Uncensored:q4_K_M",
"openelm": "OpenELM-3B:q4_K_M",
Expand Down Expand Up @@ -413,6 +414,7 @@ class ModelType(Enum):
"Phi-3-mini-128k-instruct": ModelType.NLP,
"Phi-3-mini-4k-instruct": ModelType.NLP,
"Phi-3.5-mini-instruct": ModelType.NLP,
"Phi-4": ModelType.NLP,
"CodeQwen1.5-7B-Instruct": ModelType.NLP,
"Qwen2-0.5B-Instruct": ModelType.NLP,
"Qwen2-1.5B-Instruct": ModelType.NLP,
Expand Down
223 changes: 123 additions & 100 deletions nexa/gguf/server/nexa_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
NEXA_RUN_PROJECTOR_MAP,
NEXA_RUN_OMNI_VLM_MAP,
NEXA_RUN_OMNI_VLM_PROJECTOR_MAP,
NEXA_RUN_MODEL_MAP_AUDIO_LM,
NEXA_RUN_AUDIO_LM_PROJECTOR_MAP,
NEXA_RUN_COMPLETION_TEMPLATE_MAP,
NEXA_RUN_MODEL_PRECISION_MAP,
NEXA_RUN_MODEL_MAP_FUNCTION_CALLING,
Expand Down Expand Up @@ -80,12 +82,14 @@
)

model = None
whisper_model = None
chat_format = None
completion_template = None
hostname = socket.gethostname()
chat_completion_system_prompt = [{"role": "system", "content": "You are a helpful assistant"}]
function_call_system_prompt = [{"role": "system", "content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"}]
model_path = None
whisper_model_path = "faster-whisper-tiny" # by default, use tiny whisper model
n_ctx = None
is_local_path = False
model_type = None
Expand Down Expand Up @@ -215,6 +219,8 @@ class LoadModelRequest(BaseModel):
model_config = {
"protected_namespaces": ()
}
class LoadWhisperModelRequest(BaseModel):
whisper_model_path: str = "faster-whisper-tiny"

class DownloadModelRequest(BaseModel):
model_path: str = "llama3.2"
Expand Down Expand Up @@ -295,16 +301,16 @@ async def load_model():
raise ValueError("Multimodal and Audio models are not supported for Hugging Face")
downloaded_path, _ = pull_model(model_path, hf=is_huggingface, ms=is_modelscope)
else:
if model_path in NEXA_RUN_MODEL_MAP_VLM or model_path in NEXA_RUN_OMNI_VLM_MAP:
if model_path in NEXA_RUN_MODEL_MAP_VLM or model_path in NEXA_RUN_OMNI_VLM_MAP or model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM:
if model_path in NEXA_RUN_OMNI_VLM_MAP:
logging.info(f"Path is OmniVLM model: {model_path}")
downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_MAP[model_path])
downloaded_path, model_type = pull_model(NEXA_RUN_OMNI_VLM_MAP[model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[model_path])
else:
logging.info(f"Path is in NEXA_RUN_MODEL_MAP_VLM: {model_path}")
downloaded_path, _ = pull_model(NEXA_RUN_MODEL_MAP_VLM[model_path])
elif model_path in NEXA_RUN_MODEL_MAP_VLM:
downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_VLM[model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[model_path])
model_type = "Multimodal"
elif model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM:
downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_AUDIO_LM[model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_AUDIO_LM_PROJECTOR_MAP[model_path])
else:
downloaded_path, model_type = pull_model(model_path)

Expand Down Expand Up @@ -436,14 +442,6 @@ async def load_model():
n_gpu_layers=0, # hardcode to use CPU
)
logging.info(f"Model loaded as {model}")
elif model_type == "Audio":
with suppress_stdout_stderr():
model = WhisperModel(
downloaded_path,
device="cpu", # only support cpu for now because cuDNN needs to be installed on user's machine
compute_type="default"
)
logging.info(f"model loaded as {model}")
elif model_type == "AudioLM":
with suppress_stdout_stderr():
try:
Expand All @@ -463,7 +461,24 @@ async def load_model():
logging.info(f"model loaded as {model}")
else:
raise ValueError(f"Model {model_path} not found in Model Hub. If you are using local path, be sure to add --local_path and --model_type flags.")


async def load_whisper_model(custom_whisper_model_path=None):
global whisper_model, whisper_model_path
try:
if custom_whisper_model_path:
whisper_model_path = custom_whisper_model_path
downloaded_path, _ = pull_model(whisper_model_path)
with suppress_stdout_stderr():
whisper_model = WhisperModel(
downloaded_path,
device="cpu", # only support cpu for now because cuDNN needs to be installed on user's machine
compute_type="default"
)
logging.info(f"whisper model loaded as {whisper_model}")
except Exception as e:
logging.error(f"Error loading Whisper model: {e}")
raise ValueError(f"Failed to load Whisper model: {str(e)}")

def nexa_run_text_generation(
prompt, temperature, stop_words, max_new_tokens, top_k, top_p, logprobs=None, stream=False, is_chat_completion=True
) -> Dict[str, Any]:
Expand Down Expand Up @@ -710,16 +725,23 @@ def _resp_async_generator(streamer):
async def download_model(request: DownloadModelRequest):
"""Download a model from the model hub"""
try:
if request.model_path in NEXA_RUN_MODEL_MAP_VLM: # for Multimodal models
downloaded_path, _ = pull_model(NEXA_RUN_MODEL_MAP_VLM[request.model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[request.model_path])
if request.model_path in NEXA_RUN_MODEL_MAP_VLM or request.model_path in NEXA_RUN_OMNI_VLM_MAP or request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM: # models and projectors
if request.model_path in NEXA_RUN_MODEL_MAP_VLM:
downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_VLM[request.model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[request.model_path])
elif request.model_path in NEXA_RUN_OMNI_VLM_MAP:
downloaded_path, model_type = pull_model(NEXA_RUN_OMNI_VLM_MAP[request.model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[request.model_path])
elif request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM:
downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_AUDIO_LM[request.model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_AUDIO_LM_PROJECTOR_MAP[request.model_path])
return {
"status": "success",
"message": "Successfully downloaded multimodal model and projector",
"message": "Successfully downloaded model and projector",
"model_path": request.model_path,
"model_local_path": downloaded_path,
"projector_local_path": projector_downloaded_path,
"model_type": "Multimodal"
"model_type": model_type
}
else:
downloaded_path, model_type = pull_model(request.model_path)
Expand Down Expand Up @@ -768,6 +790,26 @@ async def load_different_model(request: LoadModelRequest):
detail=f"Failed to load model: {str(e)}"
)

@app.post("/v1/load_whisper_model", tags=["Model"])
async def load_different_whisper_model(request: LoadWhisperModelRequest):
"""Load a different Whisper model while maintaining the global model state"""
try:
global whisper_model_path
whisper_model_path = request.whisper_model_path
await load_whisper_model(custom_whisper_model_path=whisper_model_path)

return {
"status": "success",
"message": f"Successfully loaded Whisper model: {whisper_model_path}",
"model_type": "Audio",
}
except Exception as e:
logging.error(f"Error loading Whisper model: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to load Whisper model: {str(e)}"
)

@app.get("/v1/list_models", tags=["Model"])
async def list_models():
"""List all models available in the model hub"""
Expand Down Expand Up @@ -1123,10 +1165,10 @@ async def process_audio(
temperature: Optional[float] = Query(0.0, description="Temperature for sampling.")
):
try:
if model_type != "Audio":
if not whisper_model:
raise HTTPException(
status_code=400,
detail="The model that is loaded is not an Audio model. Please use an Audio model."
detail="Whisper model is not loaded. Please load a Whisper model first."
)

with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_audio:
Expand All @@ -1146,7 +1188,7 @@ async def process_audio(
if task == "transcribe" and language:
task_params["language"] = language

segments, _ = model.transcribe(temp_audio_path, **task_params)
segments, _ = whisper_model.transcribe(temp_audio_path, **task_params)
result_text = "".join(segment.text for segment in segments)
return JSONResponse(content={"text": result_text})

Expand All @@ -1166,94 +1208,75 @@ async def processing_stream_audio(
language: Optional[str] = Query("auto", description="Language code (e.g., 'en', 'fr')"),
min_chunk: Optional[float] = Query(1.0, description="Minimum chunk duration for streaming"),
):
# Read the entire file into memory
audio_bytes = await file.read()
a_full = load_audio_from_bytes(audio_bytes)
duration = len(a_full) / SAMPLING_RATE

# Only include language parameter if task is "transcribe"
# For "translate", the language is always defined as "en"
if task == "transcribe" and language != "auto":
used_language = language
else:
used_language = None

warmup_audio = a_full[:SAMPLING_RATE] # first second
model.transcribe(warmup_audio)

streamer = StreamASRProcessor(model, task, used_language)

start = time.time()
beg = 0.0

def stream_generator():
nonlocal beg
while beg < duration:
now = time.time() - start
if now < beg + min_chunk:
time.sleep((beg + min_chunk) - now)
end = time.time() - start
if end > duration:
end = duration
try:
if not whisper_model:
raise HTTPException(
status_code=400,
detail="Whisper model is not loaded. Please load a Whisper model first."
)

chunk_samples = int((end - beg)*SAMPLING_RATE)
chunk_audio = a_full[int(beg*SAMPLING_RATE):int(beg*SAMPLING_RATE)+chunk_samples]
beg = end
# Read the entire file into memory
audio_bytes = await file.read()
a_full = load_audio_from_bytes(audio_bytes)
duration = len(a_full) / SAMPLING_RATE

streamer.insert_audio_chunk(chunk_audio)
o = streamer.process_iter()
# Only include language parameter if task is "transcribe"
# For "translate", the language is always defined as "en"
if task == "transcribe" and language != "auto":
used_language = language
else:
used_language = None

warmup_audio = a_full[:SAMPLING_RATE] # first second
whisper_model.transcribe(warmup_audio)

streamer = StreamASRProcessor(whisper_model, task, used_language)

start = time.time()
beg = 0.0

def stream_generator():
nonlocal beg
while beg < duration:
now = time.time() - start
if now < beg + min_chunk:
time.sleep((beg + min_chunk) - now)
end = time.time() - start
if end > duration:
end = duration

chunk_samples = int((end - beg)*SAMPLING_RATE)
chunk_audio = a_full[int(beg*SAMPLING_RATE):int(beg*SAMPLING_RATE)+chunk_samples]
beg = end

streamer.insert_audio_chunk(chunk_audio)
o = streamer.process_iter()
if o[0] is not None:
data = {
"emission_time_ms": (time.time()-start)*1000,
"segment_start_ms": o[0]*1000,
"segment_end_ms": o[1]*1000,
"text": o[2]
}
yield f"data: {json.dumps(data)}\n\n".encode("utf-8")

# Final flush
o = streamer.finish()
if o[0] is not None:
data = {
"emission_time_ms": (time.time()-start)*1000,
"segment_start_ms": o[0]*1000,
"segment_end_ms": o[1]*1000,
"text": o[2]
"text": o[2],
"final": True
}
yield f"data: {json.dumps(data)}\n\n".encode("utf-8")

# Final flush
o = streamer.finish()
if o[0] is not None:
data = {
"emission_time_ms": (time.time()-start)*1000,
"segment_start_ms": o[0]*1000,
"segment_end_ms": o[1]*1000,
"text": o[2],
"final": True
}
yield f"data: {json.dumps(data)}\n\n".encode("utf-8")

return StreamingResponse(stream_generator(), media_type="application/x-ndjson")

@app.post("/v1/audio/translations", tags=["Audio"])
async def translate_audio(
file: UploadFile = File(...),
beam_size: Optional[int] = Query(5, description="Beam size for translation"),
temperature: Optional[float] = Query(0.0, description="Temperature for sampling"),
):
try:
if model_type != "Audio":
raise HTTPException(
status_code=400,
detail="The model that is loaded is not an Audio model. Please use an Audio model for audio translation."
)
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_audio:
temp_audio.write(await file.read())
temp_audio_path = temp_audio.name
return StreamingResponse(stream_generator(), media_type="application/x-ndjson")

translate_params = {
"beam_size": beam_size,
"task": "translate",
"temperature": temperature,
"vad_filter": True
}
segments, _ = model.transcribe(temp_audio_path, **translate_params)
translation = "".join(segment.text for segment in segments)
return JSONResponse(content={"text": translation})
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error during translation: {str(e)}")
finally:
os.unlink(temp_audio_path)
logging.error(f"Error in audio processing stream: {e}")
raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/audiolm/chat/completions", tags=["AudioLM"])
async def audio_chat_completions(
Expand Down