From e5b772954a48c0f2749b2bbc03da2c60bc671ed3 Mon Sep 17 00:00:00 2001 From: haixuanTao Date: Wed, 11 Sep 2024 11:58:10 +0200 Subject: [PATCH] Remove flash_attn from pyproject --- node-hub/dora-qwenvl/dora_qwenvl/main.py | 39 ++++++++++++++++-------- node-hub/dora-qwenvl/pyproject.toml | 2 +- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/node-hub/dora-qwenvl/dora_qwenvl/main.py b/node-hub/dora-qwenvl/dora_qwenvl/main.py index 65112584c..d659b4df2 100644 --- a/node-hub/dora-qwenvl/dora_qwenvl/main.py +++ b/node-hub/dora-qwenvl/dora_qwenvl/main.py @@ -13,23 +13,32 @@ "Describe this image", ) -# default: Load the model on the available device(s) -model = Qwen2VLForConditionalGeneration.from_pretrained( - CUSTOM_MODEL_PATH, - torch_dtype="auto", - device_map="auto", - attn_implementation="flash_attention_2", -) +# Check if flash_attn is installed +try: + import flash_attn + + model = Qwen2VLForConditionalGeneration.from_pretrained( + CUSTOM_MODEL_PATH, + torch_dtype="auto", + device_map="auto", + attn_implementation="flash_attention_2", + ) +except ImportError: + model = Qwen2VLForConditionalGeneration.from_pretrained( + CUSTOM_MODEL_PATH, + torch_dtype="auto", + device_map="auto", + ) + # default processer processor = AutoProcessor.from_pretrained(DEFAULT_PATH) -def generate(image: np.array, question): +def generate(frames: dict, question): """ Generate the response to the question given the image using Qwen2 model. """ - image = Image.fromarray(image) messages = [ { @@ -38,7 +47,10 @@ def generate(image: np.array, question): { "type": "image", "image": image, - }, + } + for image in frames.values() + ] + + [ {"type": "text", "text": question}, ], } @@ -73,11 +85,11 @@ def generate(image: np.array, question): def main(): + pa.array([]) # initialize pyarrow array node = Node() question = DEFAULT_QUESTION - frame = None - pa.array([]) # initialize pyarrow array + frames = {} for event in node: event_type = event["type"] @@ -85,7 +97,7 @@ def main(): if event_type == "INPUT": event_id = event["id"] - if event_id == "image": + if "image" in event_id: storage = event["value"] metadata = event["metadata"] encoding = metadata["encoding"] @@ -112,6 +124,7 @@ def main(): pass else: raise RuntimeError(f"Unsupported image encoding: {encoding}") + frames[event_id] = Image.fromarray(frame) elif event_id == "tick": if frame is None: diff --git a/node-hub/dora-qwenvl/pyproject.toml b/node-hub/dora-qwenvl/pyproject.toml index d5ca3b7b3..8df2d80b5 100644 --- a/node-hub/dora-qwenvl/pyproject.toml +++ b/node-hub/dora-qwenvl/pyproject.toml @@ -19,7 +19,7 @@ torchvision = "^0.19" transformers = { git = "https://github.com/huggingface/transformers" } qwen-vl-utils = "^0.0.2" accelerate = "^0.33" -flash_attn = "^2.6.1" +# flash_attn = "^2.6.1" # Install using: pip install -U flash-attn --no-build-isolation [tool.poetry.scripts]