Merge pull request #70 from NexaAI/david/bugfix
Fix projector name (David)
zhiyuan8 authored Sep 5, 2024
2 parents 5f7a44e + 162eec5 commit 5c25488
Showing 6 changed files with 77 additions and 52 deletions.
76 changes: 43 additions & 33 deletions nexa/cli/entry.py
@@ -15,27 +15,32 @@ def run_ggml_inference(args):
 
     stop_words = kwargs.pop("stop_words", [])
 
-    if run_type == "NLP":
-        from nexa.gguf.nexa_inference_text import NexaTextInference
-        inference = NexaTextInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs)
-    elif run_type == "Computer Vision":
-        from nexa.gguf.nexa_inference_image import NexaImageInference
-        inference = NexaImageInference(model_path=model_path, local_path=local_path, **kwargs)
-        if hasattr(args, 'streamlit') and args.streamlit:
-            inference.run_streamlit(model_path)
-        elif args.img2img:
-            inference.run_img2img()
-        else:
-            inference.run_txt2img()
-        return
-    elif run_type == "Multimodal":
-        from nexa.gguf.nexa_inference_vlm import NexaVLMInference
-        inference = NexaVLMInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs)
-    elif run_type == "Audio":
-        from nexa.gguf.nexa_inference_voice import NexaVoiceInference
-        inference = NexaVoiceInference(model_path=model_path, local_path=local_path, **kwargs)
-    else:
-        raise ValueError(f"Unknown task: {run_type}")
+    try:
+        if run_type == "NLP":
+            from nexa.gguf.nexa_inference_text import NexaTextInference
+            inference = NexaTextInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs)
+        elif run_type == "Computer Vision":
+            from nexa.gguf.nexa_inference_image import NexaImageInference
+            inference = NexaImageInference(model_path=model_path, local_path=local_path, **kwargs)
+            if hasattr(args, 'streamlit') and args.streamlit:
+                inference.run_streamlit(model_path)
+            elif args.img2img:
+                inference.run_img2img()
+            else:
+                inference.run_txt2img()
+            return
+        elif run_type == "Multimodal":
+            from nexa.gguf.nexa_inference_vlm import NexaVLMInference
+            inference = NexaVLMInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs)
+        elif run_type == "Audio":
+            from nexa.gguf.nexa_inference_voice import NexaVoiceInference
+            inference = NexaVoiceInference(model_path=model_path, local_path=local_path, **kwargs)
+        else:
+            print(f"Unknown task: {run_type}. Skipping inference.")
+            return
+    except Exception as e:
+        print("Error loading GGUF models, please refer to our docs to install the nexaai package: https://docs.nexaai.com/getting-started/installation")
+        return
 
     if hasattr(args, 'streamlit') and args.streamlit:
         inference.run_streamlit(model_path)
@@ -49,20 +54,25 @@ def run_onnx_inference(args):
     from nexa.general import pull_model
     local_path, run_type = pull_model(model_path)
 
-    if run_type == "NLP":
-        from nexa.onnx.nexa_inference_text import NexaTextInference as NexaTextOnnxInference
-        inference = NexaTextOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
-    elif run_type == "Computer Vision":
-        from nexa.onnx.nexa_inference_image import NexaImageInference as NexaImageOnnxInference
-        inference = NexaImageOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
-    elif run_type == "Audio":
-        from nexa.onnx.nexa_inference_voice import NexaVoiceInference as NexaVoiceOnnxInference
-        inference = NexaVoiceOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
-    elif run_type == "TTS":
-        from nexa.onnx.nexa_inference_tts import NexaTTSInference as NexaTTSOnnxInference
-        inference = NexaTTSOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
-    else:
-        raise ValueError(f"Unknown task: {run_type}")
+    try:
+        if run_type == "NLP":
+            from nexa.onnx.nexa_inference_text import NexaTextInference as NexaTextOnnxInference
+            inference = NexaTextOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
+        elif run_type == "Computer Vision":
+            from nexa.onnx.nexa_inference_image import NexaImageInference as NexaImageOnnxInference
+            inference = NexaImageOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
+        elif run_type == "Audio":
+            from nexa.onnx.nexa_inference_voice import NexaVoiceInference as NexaVoiceOnnxInference
+            inference = NexaVoiceOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
+        elif run_type == "TTS":
+            from nexa.onnx.nexa_inference_tts import NexaTTSInference as NexaTTSOnnxInference
+            inference = NexaTTSOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
+        else:
+            print(f"Unknown task: {run_type}. Skipping inference.")
+            return
+    except Exception as e:
+        print("Error loading ONNX models, please refer to our docs to install the nexaai[onnx] package: https://docs.nexaai.com/getting-started/installation")
+        return
 
     if hasattr(args, 'streamlit') and args.streamlit:
         inference.run_streamlit(model_path)
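Both entry points now degrade gracefully when a backend cannot be loaded: an unknown run_type prints a message instead of raising, and an import failure surfaces an install hint instead of a traceback. A minimal sketch of the same pattern (the name load_inference and the single NLP branch are illustrative, not from the repository):

    def load_inference(run_type, model_path, local_path, **kwargs):
        # Defer the backend import so a missing optional dependency only
        # fails when that backend is actually requested.
        try:
            if run_type == "NLP":
                from nexa.gguf.nexa_inference_text import NexaTextInference
                return NexaTextInference(model_path=model_path, local_path=local_path, **kwargs)
            print(f"Unknown task: {run_type}. Skipping inference.")
        except Exception:
            # Most commonly an ImportError from an uninstalled backend package.
            print("Error loading GGUF models, please refer to our docs: "
                  "https://docs.nexaai.com/getting-started/installation")
        return None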
2 changes: 1 addition & 1 deletion nexa/constants.py
@@ -120,7 +120,7 @@
 
 NEXA_RUN_PROJECTOR_MAP = {
     "nanollava": "nanoLLaVA:projector-fp16",
-    "nanoLLaVA:fp16": "nanoLLaVA:project-fp16",
+    "nanoLLaVA:fp16": "nanoLLaVA:projector-fp16",
     "llava-phi3": "llava-phi-3-mini:projector-q4_0",
     "llava-phi-3-mini:q4_0": "llava-phi-3-mini:projector-q4_0",
     "llava-phi-3-mini:fp16": "llava-phi-3-mini:projector-fp16",
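This one-character fix matters because NEXA_RUN_PROJECTOR_MAP pairs each multimodal model tag with the name of its projector weights, and "project-fp16" pointed at a nonexistent artifact. A quick check, assuming the map is imported from nexa.constants as the diff suggests:

    from nexa.constants import NEXA_RUN_PROJECTOR_MAP

    # Before the fix this lookup returned "nanoLLaVA:project-fp16", a tag with
    # no matching projector artifact, so loading nanoLLaVA:fp16 could not find
    # its projector weights.
    assert NEXA_RUN_PROJECTOR_MAP["nanoLLaVA:fp16"] == "nanoLLaVA:projector-fp16"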
2 changes: 1 addition & 1 deletion nexa/gguf/nexa_inference_image.py
@@ -115,7 +115,7 @@ def _retry(self, func, *args, **kwargs):
             except Exception as e:
                 logging.error(f"Attempt {attempt + 1} failed with error: {e}")
                 time.sleep(1)
-        logging.error("All retry attempts failed.")
+        print("All retry attempts failed because of an Out of Memory error; try to use smaller models...")
         return None
 
     def txt2img(
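For context, _retry wraps a flaky call (image generation can fail under memory pressure), and the new message tells the user what to do when every attempt fails. A self-contained sketch of the same retry shape; the attempt count and delay are assumptions, since the hunk does not show the loop header:

    import logging
    import time

    def retry(func, *args, attempts=3, delay=1.0, **kwargs):
        # Call func up to `attempts` times, logging each failure and pausing
        # briefly between tries; give up with a user-facing hint.
        for attempt in range(attempts):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                logging.error(f"Attempt {attempt + 1} failed with error: {e}")
                time.sleep(delay)
        print("All retry attempts failed because of an Out of Memory error; try to use smaller models...")
        return None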
16 changes: 14 additions & 2 deletions nexa/gguf/streamlit/streamlit_image_chat.py
@@ -4,6 +4,7 @@
 from nexa.general import pull_model
 import streamlit as st
 from nexa.gguf.nexa_inference_image import NexaImageInference
+import io
 
 default_model = sys.argv[1]
 
@@ -106,5 +107,16 @@ def generate_images(nexa_model: NexaImageInference, prompt: str, negative_prompt
                 st.session_state.nexa_model, prompt, negative_prompt
             )
             st.success("Images generated successfully!")
-            for image in images:
-                st.image(image, caption="Generated Image", use_column_width=True)
+            for i, image in enumerate(images):
+                st.image(image, caption=f"Generated Image {i + 1}", use_column_width=True)
+
+                img_byte_arr = io.BytesIO()
+                image.save(img_byte_arr, format='PNG')
+                img_byte_arr = img_byte_arr.getvalue()
+
+                st.download_button(
+                    label=f"Download Image {i + 1}",
+                    data=img_byte_arr,
+                    file_name=f"generated_image_{i + 1}.png",
+                    mime="image/png"
+                )
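The loop above encodes each PIL image as PNG bytes in memory so st.download_button can serve it without writing to disk. One caveat: Streamlit derives widget identity from a widget's arguments, so buttons rendered in a loop need something unique per iteration, such as the index used in the label and file name, or an explicit key. A sketch of the same idea factored into a helper (offer_download is an illustrative name, not from the repository):

    import io

    import streamlit as st
    from PIL import Image

    def offer_download(image: Image.Image, index: int) -> None:
        # Encode the image as PNG in memory rather than via a temp file.
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        st.download_button(
            label=f"Download Image {index + 1}",
            data=buf.getvalue(),
            file_name=f"generated_image_{index + 1}.png",
            mime="image/png",
            key=f"download_{index}",  # explicit key avoids DuplicateWidgetID in loops
        )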
30 changes: 17 additions & 13 deletions nexa/onnx/streamlit/streamlit_image_chat.py
@@ -3,7 +3,7 @@
 
 import numpy as np
 import streamlit as st
-
+from optimum.onnxruntime import ORTLatentConsistencyModelPipeline
 from nexa.general import pull_model
 from nexa.onnx.nexa_inference_image import NexaImageInference
 
@@ -15,11 +15,11 @@ def load_model(model_path):
     local_path, run_type = pull_model(model_path)
     nexa_model = NexaImageInference(model_path=model_path, local_path=local_path)
 
-    if nexa_model.downloaded_onnx_folder is None:
+    if nexa_model.download_onnx_folder is None:
         st.error("Failed to download the model. Please check the model path.")
         return None
 
-    nexa_model._load_model(nexa_model.downloaded_onnx_folder)
+    nexa_model._load_model(nexa_model.download_onnx_folder)
     return nexa_model
 
 
@@ -30,17 +30,21 @@ def generate_images(nexa_model: NexaImageInference, prompt, negative_prompt):
 
     generator = np.random.RandomState(nexa_model.params["random_seed"])
 
-    images = nexa_model.pipeline(
-        prompt=prompt,
-        negative_prompt=negative_prompt if negative_prompt else None,
-        num_inference_steps=nexa_model.params["num_inference_steps"],
-        num_images_per_prompt=nexa_model.params["num_images_per_prompt"],
-        height=nexa_model.params["height"],
-        width=nexa_model.params["width"],
-        generator=generator,
-        guidance_scale=nexa_model.params["guidance_scale"],
-    ).images
+    is_lcm_pipeline = isinstance(nexa_model.pipeline, ORTLatentConsistencyModelPipeline)
+
+    pipeline_kwargs = {
+        "prompt": prompt,
+        "num_inference_steps": nexa_model.params["num_inference_steps"],
+        "num_images_per_prompt": nexa_model.params["num_images_per_prompt"],
+        "height": nexa_model.params["height"],
+        "width": nexa_model.params["width"],
+        "generator": generator,
+        "guidance_scale": nexa_model.params["guidance_scale"],
+    }
+    if not is_lcm_pipeline and negative_prompt:
+        pipeline_kwargs["negative_prompt"] = negative_prompt
+
+    images = nexa_model.pipeline(**pipeline_kwargs).images
     return images
 
 
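The class check exists because Latent Consistency Model pipelines distill classifier-free guidance into the model, so ORTLatentConsistencyModelPipeline does not accept a negative_prompt argument the way standard diffusion pipelines do (this may vary across optimum versions). A more general alternative, sketched here as an assumption rather than the project's approach, is to inspect the pipeline's call signature:

    import inspect

    def supports_negative_prompt(pipeline) -> bool:
        # Alternative to the isinstance check above: look for the parameter
        # in the pipeline's call signature instead of hard-coding the class.
        return "negative_prompt" in inspect.signature(pipeline.__call__).parameters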
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -24,10 +24,9 @@ dependencies = [
     "prompt_toolkit",
     "tqdm", # Shared dependencies
     "tabulate",
-    "streamlit",
+    "streamlit>=1.37.1",
     "streamlit-audiorec",
     "python-multipart",
-    "streamlit-audiorec",
     "cmake",
 ]
 classifiers = [