Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 421.image caption generator benchmark and added its data in bench… #218

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks-data
6 changes: 6 additions & 0 deletions benchmarks/400.inference/421.image-captioning/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"timeout": 60,
"memory": 256,
"languages": ["python"]
}

40 changes: 40 additions & 0 deletions benchmarks/400.inference/421.image-captioning/input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import glob
import os

def buckets_count():
    """Return the bucket counts required by this benchmark.

    The result is the fixed pair ``(1, 1)``.
    NOTE(review): presumably (input buckets, output buckets), which matches
    generate_input using only index 0 of each path list -- confirm against
    the benchmark framework.
    """
    n_input = 1
    n_output = 1
    return n_input, n_output

def generate_input(data_dir, size, benchmarks_bucket, input_paths, output_paths, upload_func):
    '''
    Generate test, small, and large workload for image captioning benchmark.

    Every ``*.jpg``, ``*.png`` and ``*.jpeg`` file located directly in
    ``data_dir`` is uploaded to input bucket 0. The returned configuration
    references the last uploaded image. Files are processed in sorted order
    so that both the upload sequence and the selected image are
    reproducible; ``glob`` alone returns entries in arbitrary order.

    :param data_dir: Directory where benchmark data is placed
    :param size: Workload size (not used; every size gets the same input)
    :param benchmarks_bucket: Storage container for the benchmark
    :param input_paths: List of input paths
    :param output_paths: List of output paths
    :param upload_func: Upload function taking three params (bucket_idx, key, filepath)
    :return: Dictionary with the 'object' and 'bucket' sections of the function input
    :raises ValueError: If ``data_dir`` contains no matching image file
    '''
    patterns = ('*.jpg', '*.png', '*.jpeg')
    input_files = sorted(
        path
        for pattern in patterns
        for path in glob.glob(os.path.join(data_dir, pattern))
    )

    if not input_files:
        raise ValueError("No input files found in the provided directory.")

    uploaded_keys = []
    for path in input_files:
        # Keys are relative to data_dir so they do not leak local paths.
        key = os.path.relpath(path, data_dir)
        upload_func(0, key, path)
        uploaded_keys.append(key)

    return {
        'object': {
            # The benchmark captions a single image: the last one in sorted order.
            'key': uploaded_keys[-1],
            'width': 200,
            'height': 200
        },
        'bucket': {
            'bucket': benchmarks_bucket,
            'input': input_paths[0],
            'output': output_paths[0]
        }
    }
67 changes: 67 additions & 0 deletions benchmarks/400.inference/421.image-captioning/python/function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import datetime
import io
import os
from urllib.parse import unquote_plus
from PIL import Image
import torch
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from . import storage

# Pre-trained ViT-GPT2 image captioning checkpoint, loaded at module import.
# Model URL: https://huggingface.co/nlpconnect/vit-gpt2-image-captioning
# License: Apache 2.0 License (https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md)
_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)
# Switch to evaluation mode; the model is only used for inference here.
model.eval()

# The image processor and tokenizer come from the same checkpoint as the model.
image_processor = ViTImageProcessor.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)

# Storage client used by handler() for downloads and uploads.
client = storage.storage.get_instance()

def generate_caption(image_bytes):
    """Generate a text caption for an encoded image.

    :param image_bytes: Bytes of an image file that PIL can open
    :return: Caption decoded from the first generated sequence, with
        special tokens removed
    """
    rgb_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    model_inputs = image_processor(images=rgb_image, return_tensors="pt")

    # Inference only: gradient tracking is not needed.
    with torch.no_grad():
        token_ids = model.generate(model_inputs.pixel_values, max_length=16, num_beams=4)

    return tokenizer.decode(token_ids[0], skip_special_tokens=True)

def handler(event):
    """Caption one image from storage and upload the caption as a text file.

    Reads the bucket name and the input/output prefixes from
    ``event['bucket']`` and the URL-encoded object key from
    ``event['object']['key']``. The image is downloaded, captioned with
    generate_caption, and the UTF-8 caption is uploaded under the output
    prefix with the image extension replaced by ``.txt``.

    :param event: Invocation payload with 'bucket' and 'object' sections
    :return: Dictionary with 'result' (bucket and key of the caption file)
        and 'measurement' (download/upload/compute times in microseconds,
        download/upload sizes in bytes)
    """
    bucket_cfg = event.get('bucket')
    bucket = bucket_cfg.get('bucket')
    input_prefix = bucket_cfg.get('input')
    output_prefix = bucket_cfg.get('output')
    key = unquote_plus(event.get('object').get('key'))

    download_begin = datetime.datetime.now()
    # NOTE(review): download_stream is assumed to return the object as bytes;
    # the result is wrapped in io.BytesIO and measured with len() below.
    img = client.download_stream(bucket, os.path.join(input_prefix, key))
    download_end = datetime.datetime.now()

    process_begin = datetime.datetime.now()
    caption = generate_caption(img)
    process_end = datetime.datetime.now()

    # Encode once and build the target path before the timer starts, so the
    # upload measurement covers only the storage call.
    caption_bytes = caption.encode('utf-8')
    caption_file_path = os.path.join(output_prefix, os.path.splitext(key)[0] + '.txt')

    upload_begin = datetime.datetime.now()
    # NOTE(review): upload_stream may return the key under which the object
    # was actually stored (storage wrappers can uniquify names) -- confirm.
    # Prefer that value; fall back to the requested path if nothing is returned.
    stored_key = client.upload_stream(bucket, caption_file_path, io.BytesIO(caption_bytes))
    upload_end = datetime.datetime.now()

    one_microsecond = datetime.timedelta(microseconds=1)
    download_time = (download_end - download_begin) / one_microsecond
    upload_time = (upload_end - upload_begin) / one_microsecond
    process_time = (process_end - process_begin) / one_microsecond

    return {
        'result': {
            'bucket': bucket,
            'key': stored_key or caption_file_path
        },
        'measurement': {
            'download_time': download_time,
            'download_size': len(img),
            'upload_time': upload_time,
            'upload_size': len(caption_bytes),
            'compute_time': process_time
        }
    }
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
transformers==4.44.2
octonawish-akcodes marked this conversation as resolved.
Show resolved Hide resolved
torch==2.4.0
pillow==10.4.0