From e31296b0c87be1ecb336bc0eaf658f8b4ec83675 Mon Sep 17 00:00:00 2001
From: jwyang
Date: Wed, 8 Nov 2023 23:49:51 +0000
Subject: [PATCH] refine demo

---
 demo_gpt4v_som.py                              | 63 ++++++++-----------
 demo_som.py                                    |  1 -
 gpt4v.py                                       | 38 ++++++-----
 .../sam/tasks/inference_sam_m2m_auto.py        |  8 +--
 .../tasks/inference_sam_m2m_interactive.py     |  8 +--
 .../seem/tasks/inference_seem_interactive.py   |  6 +-
 .../seem/tasks/inference_seem_pano.py          |  6 +-
 .../tasks/inference_semsam_m2m_auto.py         |  8 +--
 task_adapter/utils/visualizer.py               |  2 +-
 9 files changed, 58 insertions(+), 82 deletions(-)

diff --git a/demo_gpt4v_som.py b/demo_gpt4v_som.py
index 2614474c..01231d23 100644
--- a/demo_gpt4v_som.py
+++ b/demo_gpt4v_som.py
@@ -76,8 +76,13 @@ with torch.autocast(device_type='cuda', dtype=torch.float16):
     model_seem.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(COCO_PANOPTIC_CLASSES + ["background"], is_eval=True)
 
 
+history_images = []
+history_masks = []
+history_texts = []
 @torch.no_grad()
 def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs):
+    global history_images; history_images = []
+    global history_masks; history_masks = []
     if slider < 1.5:
         model_name = 'seem'
     elif slider > 2.5:
@@ -119,68 +124,54 @@ def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs
 
     if model_name == 'semantic-sam':
         model = model_semsam
-        output = inference_semsam_m2m_auto(model, image['image'], level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)
+        output, mask = inference_semsam_m2m_auto(model, image['image'], level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)
 
     elif model_name == 'sam':
         model = model_sam
         if mode == "Automatic":
-            output = inference_sam_m2m_auto(model, image['image'], text_size, label_mode, alpha, anno_mode)
+            output, mask = inference_sam_m2m_auto(model, image['image'], text_size, label_mode, alpha, anno_mode)
         elif mode == "Interactive":
-            output = inference_sam_m2m_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
+            output, mask = inference_sam_m2m_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
 
     elif model_name == 'seem':
         model = model_seem
         if mode == "Automatic":
-            output = inference_seem_pano(model, image['image'], text_size, label_mode, alpha, anno_mode)
+            output, mask = inference_seem_pano(model, image['image'], text_size, label_mode, alpha, anno_mode)
         elif mode == "Interactive":
-            output = inference_seem_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
+            output, mask = inference_seem_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
 
     # convert output to PIL image
-    output_img = Image.fromarray(output)
-    output_img.save('temp.jpg')
-    return output
+    history_masks.append(mask)
+    history_images.append(Image.fromarray(output))
+    return (output, [])
 
 def gpt4v_response(message, history):
+    global history_images
+    global history_texts; history_texts = []
     try:
-        res = request_gpt4v(message)
-        # save res to txt
-        with open('temp.txt', 'w') as f:
-            f.write(res)
+        res = request_gpt4v(message, history_images[0])
+        history_texts.append(res)
         return res
     except Exception as e:
         return None
 
 def highlight(mode, alpha, label_mode, anno_mode, *args, **kwargs):
-    # read temp.txt
-    with open('temp.txt', 'r') as f:
-        res = f.read()
-    # read temp_mask.jpg
-    mask = Image.open('temp_mask.jpg')
-    # convert mask to gray scale
-    mask = mask.convert('L')
+    res = history_texts[0]
     # find the seperate numbers in sentence res
     res = res.split(' ')
     res = [r.replace('.','').replace(',','').replace(')','').replace('"','') for r in res]
+    # find all numbers in '[]'
+    res = [r for r in res if '[' in r]
+    res = [r.split('[')[1] for r in res]
+    res = [r.split(']')[0] for r in res]
     res = [r for r in res if r.isdigit()]
-    # convert res to unique
     res = list(set(res))
-    # draw mask
-    # resize image['image'] into mask size
-    # read temp.jpg
-    image = Image.open('temp.jpg')
-    image_out = image.resize(mask.size)
-    visual = Visualizer(image_out, metadata=metadata)
+    sections = []
     for i, r in enumerate(res):
-        mask_i = np.copy(np.asarray(mask))
-        mask_i[mask_i != int(r)] = False
-        mask_i[mask_i == int(r)] = True
-        demo = visual.draw_binary_mask_with_number(mask_i, text='', label_mode=label_mode, alpha=0.6, anno_mode=["Mark", "Mask"])
-        del mask_i
-    if len(res) > 0:
-        im = demo.get_image()
-        return im
-    return image_out
+        mask_i = history_masks[0][int(r)-1]['segmentation']
+        sections.append((mask_i, r))
+    return (history_images[0], sections)
 
 
 class ImageMask(gr.components.Image):
     """
@@ -205,7 +196,7 @@ def preprocess(self, x):
 slider = gr.Slider(1, 3, value=1.8, label="Granularity") # info="Choose in [1, 1.5), [1.5, 2.5), [2.5, 3] for [seem, semantic-sam (multi-level), sam]"
 mode = gr.Radio(['Automatic', 'Interactive', ], value='Automatic', label="Segmentation Mode")
 anno_mode = gr.CheckboxGroup(choices=["Mark", "Mask", "Box"], value=['Mark'], label="Annotation Mode")
-image_out = gr.Image(label="SoM Visual Prompt",type="pil", height=512)
+image_out = gr.AnnotatedImage(label="SoM Visual Prompt",type="pil", height=512)
 runBtn = gr.Button("Run")
 highlightBtn = gr.Button("Highlight")
 bot = gr.Chatbot(label="GPT-4V + SoM", height=256)
diff --git a/demo_som.py b/demo_som.py
index fcea1539..bbeb3af7 100644
--- a/demo_som.py
+++ b/demo_som.py
@@ -62,7 +62,6 @@
 
 @torch.no_grad()
 def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs):
-
     if slider < 1.5:
         model_name = 'seem'
     elif slider > 2.5:
diff --git a/gpt4v.py b/gpt4v.py
index 342e1ed2..1da28c39 100644
--- a/gpt4v.py
+++ b/gpt4v.py
@@ -1,6 +1,7 @@
 import os
 import base64
 import requests
+from io import BytesIO
 
 # Get OpenAI API Key from environment variable
 api_key = os.environ["OPENAI_API_KEY"]
@@ -10,31 +11,36 @@
 }
 
 metaprompt = '''
-- You always generate the answer in markdown format. For any marks mentioned in your answer, please highlight them in a red color and bold font.
+- For any marks mentioned in your answer, please highlight them with [].
 '''
 
 # Function to encode the image
-def encode_image(image_path):
+def encode_image_from_file(image_path):
     with open(image_path, "rb") as image_file:
         return base64.b64encode(image_file.read()).decode('utf-8')
 
-def prepare_inputs(message):
+def encode_image_from_pil(image):
+    buffered = BytesIO()
+    image.save(buffered, format="JPEG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
 
-    # Path to your image
-    image_path = "temp.jpg"
+def prepare_inputs(message, image):
 
-    # Getting the base64 string
-    base64_image = encode_image(image_path)
+    # # Path to your image
+    # image_path = "temp.jpg"
+    # # Getting the base64 string
+    # base64_image = encode_image(image_path)
+    base64_image = encode_image_from_pil(image)
 
     payload = {
         "model": "gpt-4-vision-preview",
         "messages": [
-            # {
-            #     "role": "system",
-            #     "content": [
-            #         metaprompt
-            #     ]
-            # },
+            {
+                "role": "system",
+                "content": [
+                    metaprompt
+                ]
+            },
             {
                 "role": "user",
                 "content": [
@@ -51,13 +57,13 @@ def prepare_inputs(message):
                 ]
             }
         ],
-        "max_tokens": 300
+        "max_tokens": 800
     }
 
     return payload
 
-def request_gpt4v(message):
-    payload = prepare_inputs(message)
+def request_gpt4v(message, image):
+    payload = prepare_inputs(message, image)
     response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
     res = response.json()['choices'][0]['message']['content']
     return res
diff --git a/task_adapter/sam/tasks/inference_sam_m2m_auto.py b/task_adapter/sam/tasks/inference_sam_m2m_auto.py
index 30d74801..d51cf758 100644
--- a/task_adapter/sam/tasks/inference_sam_m2m_auto.py
+++ b/task_adapter/sam/tasks/inference_sam_m2m_auto.py
@@ -51,17 +51,13 @@ def inference_sam_m2m_auto(model, image, text_size, label_mode='1', alpha=0.1, a
         # assign the mask to the mask_map
         mask_map[mask == 1] = label
         label += 1
-    im = demo.get_image()
-    # save the mask_map
-    mask_map = Image.fromarray(mask_map)
-    mask_map.save('temp_mask.jpg')
-
+    im = demo.get_image()
     # fig=plt.figure(figsize=(10, 10))
     # plt.imshow(image_ori)
     # show_anns(outputs)
     # fig.canvas.draw()
     # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-    return im
+    return im, sorted_anns
 
 
 def remove_small_regions(
diff --git a/task_adapter/sam/tasks/inference_sam_m2m_interactive.py b/task_adapter/sam/tasks/inference_sam_m2m_interactive.py
index 7025d994..5752138e 100644
--- a/task_adapter/sam/tasks/inference_sam_m2m_interactive.py
+++ b/task_adapter/sam/tasks/inference_sam_m2m_interactive.py
@@ -169,17 +169,13 @@ def inference_sam_m2m_interactive(model, image, spatial_masks, text_size, label_
         # assign the mask to the mask_map
         mask_map[mask == 1] = label
         label += 1
-    im = demo.get_image()
-    # save the mask_map
-    mask_map = Image.fromarray(mask_map)
-    mask_map.save('temp_mask.jpg')
-
+    im = demo.get_image()
     # fig=plt.figure(figsize=(10, 10))
     # plt.imshow(image_ori)
     # show_anns(outputs)
     # fig.canvas.draw()
     # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-    return im
+    return im, sorted_anns
 
 
 def remove_small_regions(
diff --git a/task_adapter/seem/tasks/inference_seem_interactive.py b/task_adapter/seem/tasks/inference_seem_interactive.py
index 09a321b5..a4b3ce9a 100644
--- a/task_adapter/seem/tasks/inference_seem_interactive.py
+++ b/task_adapter/seem/tasks/inference_seem_interactive.py
@@ -118,16 +118,12 @@ def inference_seem_interactive(model, image, spatial_masks, text_size, label_mod
         mask_map[mask == 1] = label
         label += 1
     im = demo.get_image()
-    # save the mask_map
-    mask_map = Image.fromarray(mask_map)
-    mask_map.save('temp_mask.jpg')
-
     # fig=plt.figure(figsize=(10, 10))
     # plt.imshow(image_ori)
     # show_anns(outputs)
     # fig.canvas.draw()
     # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-    return im
+    return im, sorted_anns
 
 
 def remove_small_regions(
diff --git a/task_adapter/seem/tasks/inference_seem_pano.py b/task_adapter/seem/tasks/inference_seem_pano.py
index be47b701..d75af481 100644
--- a/task_adapter/seem/tasks/inference_seem_pano.py
+++ b/task_adapter/seem/tasks/inference_seem_pano.py
@@ -113,16 +113,12 @@ def inference_seem_pano(model, image, text_size, label_mode='1', alpha=0.1, anno
         mask_map[mask == 1] = label
         label += 1
     im = demo.get_image()
-    # save the mask_map
-    mask_map = Image.fromarray(mask_map)
-    mask_map.save('temp_mask.jpg')
-
     # fig=plt.figure(figsize=(10, 10))
     # plt.imshow(image_ori)
     # show_anns(outputs)
     # fig.canvas.draw()
     # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-    return im
+    return im, sorted_anns
 
 
 def remove_small_regions(
diff --git a/task_adapter/semantic_sam/tasks/inference_semsam_m2m_auto.py b/task_adapter/semantic_sam/tasks/inference_semsam_m2m_auto.py
index 210be73f..a939a3c8 100644
--- a/task_adapter/semantic_sam/tasks/inference_semsam_m2m_auto.py
+++ b/task_adapter/semantic_sam/tasks/inference_semsam_m2m_auto.py
@@ -56,17 +56,13 @@ def inference_semsam_m2m_auto(model, image, level, all_classes, all_parts, thres
         # assign the mask to the mask_map
         mask_map[mask == 1] = label
         label += 1
-    im = demo.get_image()
-    # save the mask_map
-    mask_map = Image.fromarray(mask_map)
-    mask_map.save('temp_mask.jpg')
-
+    im = demo.get_image()
     # fig=plt.figure(figsize=(10, 10))
    # plt.imshow(image_ori)
     # show_anns(outputs)
     # fig.canvas.draw()
     # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-    return im
+    return im, sorted_anns
 
 
 def remove_small_regions(
diff --git a/task_adapter/utils/visualizer.py b/task_adapter/utils/visualizer.py
index 6786ee9f..bd78a98e 100644
--- a/task_adapter/utils/visualizer.py
+++ b/task_adapter/utils/visualizer.py
@@ -1173,7 +1173,7 @@ def draw_binary_mask_with_number(
 
         if text is not None and has_valid_segment:
             # lighter_color = tuple([x*0.2 for x in color])
-            lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
+            lighter_color = [1,1,1] # self._change_color_brightness(color, brightness_factor=0.7)
             self._draw_number_in_mask(binary_mask, text, lighter_color, label_mode)
 
         return self.output
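
Note on the reworked flow (not part of the patch itself): inference() now keeps the annotated image and the raw mask list in the history_* globals instead of writing temp.jpg / temp.txt / temp_mask.jpg, request_gpt4v() is handed the PIL image directly, and highlight() picks out the mark numbers that GPT-4V wraps in square brackets (per the new metaprompt) and returns (image, sections) pairs for gr.AnnotatedImage. Below is a minimal standalone sketch of that parsing step; the helper name extract_mark_ids and the sample reply are illustrative only and do not appear in the patch.

def extract_mark_ids(response):
    """Collect the unique mark numbers written as [N] in a GPT-4V reply."""
    tokens = response.split(' ')
    # strip punctuation the model may attach to a mark, mirroring highlight()
    tokens = [t.replace('.', '').replace(',', '').replace(')', '').replace('"', '') for t in tokens]
    # keep only the text between '[' and ']', then keep purely numeric entries
    ids = [t.split('[')[1].split(']')[0] for t in tokens if '[' in t]
    return sorted({int(t) for t in ids if t.isdigit()})

# Example: extract_mark_ids("The mug [2] is left of the laptop [5].") -> [2, 5]
# In the demo, each id r is then paired with its mask and label for gr.AnnotatedImage, roughly:
#   sections = [(history_masks[0][r - 1]['segmentation'], str(r)) for r in ids]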