Skip to content

Commit

Permalink
Remove flash_attn from pyproject
Browse files Browse the repository at this point in the history
  • Loading branch information
haixuanTao committed Sep 11, 2024
1 parent 4456aec commit e5b7729
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 14 deletions.
39 changes: 26 additions & 13 deletions node-hub/dora-qwenvl/dora_qwenvl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,32 @@
"Describe this image",
)

# Load the model on the available device(s).
# flash_attn is an optional dependency (it needs a CUDA toolchain to build, see
# pyproject.toml): when it is importable, opt into the faster FlashAttention-2
# kernels; otherwise fall back to transformers' default attention implementation.
_model_kwargs = {
    "torch_dtype": "auto",
    "device_map": "auto",
}
try:
    import flash_attn  # noqa: F401 -- imported only to probe availability

    _model_kwargs["attn_implementation"] = "flash_attention_2"
except ImportError:
    pass

model = Qwen2VLForConditionalGeneration.from_pretrained(
    CUSTOM_MODEL_PATH,
    **_model_kwargs,
)


# default processor

Check warning on line 34 in node-hub/dora-qwenvl/dora_qwenvl/main.py

View workflow job for this annotation

GitHub Actions / Typos

"processer" should be "processor".
# Processor that prepares (image, text) inputs for the model; loaded from the
# default model path rather than CUSTOM_MODEL_PATH — presumably fine-tuned
# checkpoints reuse the base processor config. NOTE(review): confirm.
processor = AutoProcessor.from_pretrained(DEFAULT_PATH)


def generate(image: np.array, question):
def generate(frames: dict, question):
"""
Generate the response to the question given the image using Qwen2 model.
"""
image = Image.fromarray(image)

messages = [
{
Expand All @@ -38,7 +47,10 @@ def generate(image: np.array, question):
{
"type": "image",
"image": image,
},
}
for image in frames.values()
]
+ [
{"type": "text", "text": question},
],
}
Expand Down Expand Up @@ -73,19 +85,19 @@ def generate(image: np.array, question):


def main():
pa.array([]) # initialize pyarrow array
node = Node()

question = DEFAULT_QUESTION
frame = None
pa.array([]) # initialize pyarrow array
frames = {}

for event in node:
event_type = event["type"]

if event_type == "INPUT":
event_id = event["id"]

if event_id == "image":
if "image" in event_id:
storage = event["value"]
metadata = event["metadata"]
encoding = metadata["encoding"]
Expand All @@ -112,6 +124,7 @@ def main():
pass
else:
raise RuntimeError(f"Unsupported image encoding: {encoding}")
frames[event_id] = Image.fromarray(frame)

elif event_id == "tick":
if frame is None:
Expand Down
2 changes: 1 addition & 1 deletion node-hub/dora-qwenvl/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ torchvision = "^0.19"
transformers = { git = "https://github.com/huggingface/transformers" }
qwen-vl-utils = "^0.0.2"
accelerate = "^0.33"
flash_attn = "^2.6.1"
# flash_attn = "^2.6.1" # Install using: pip install -U flash-attn --no-build-isolation


[tool.poetry.scripts]
Expand Down

0 comments on commit e5b7729

Please sign in to comment.