refine demo
jwyang committed Nov 8, 2023
1 parent 49194be commit e31296b
Showing 9 changed files with 58 additions and 82 deletions.
63 changes: 27 additions & 36 deletions demo_gpt4v_som.py
@@ -76,8 +76,13 @@
 with torch.autocast(device_type='cuda', dtype=torch.float16):
     model_seem.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(COCO_PANOPTIC_CLASSES + ["background"], is_eval=True)

+history_images = []
+history_masks = []
+history_texts = []
 @torch.no_grad()
 def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs):
+    global history_images; history_images = []
+    global history_masks; history_masks = []
     if slider < 1.5:
         model_name = 'seem'
     elif slider > 2.5:
@@ -119,68 +124,54 @@ def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs

     if model_name == 'semantic-sam':
         model = model_semsam
-        output = inference_semsam_m2m_auto(model, image['image'], level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)
+        output, mask = inference_semsam_m2m_auto(model, image['image'], level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)

     elif model_name == 'sam':
         model = model_sam
         if mode == "Automatic":
-            output = inference_sam_m2m_auto(model, image['image'], text_size, label_mode, alpha, anno_mode)
+            output, mask = inference_sam_m2m_auto(model, image['image'], text_size, label_mode, alpha, anno_mode)
         elif mode == "Interactive":
-            output = inference_sam_m2m_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
+            output, mask = inference_sam_m2m_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)

     elif model_name == 'seem':
         model = model_seem
         if mode == "Automatic":
-            output = inference_seem_pano(model, image['image'], text_size, label_mode, alpha, anno_mode)
+            output, mask = inference_seem_pano(model, image['image'], text_size, label_mode, alpha, anno_mode)
         elif mode == "Interactive":
-            output = inference_seem_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
+            output, mask = inference_seem_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)

-    # convert output to PIL image
-    output_img = Image.fromarray(output)
-    output_img.save('temp.jpg')
-    return output
+    history_masks.append(mask)
+    history_images.append(Image.fromarray(output))
+    return (output, [])


 def gpt4v_response(message, history):
+    global history_images
+    global history_texts; history_texts = []
     try:
-        res = request_gpt4v(message)
-        # save res to txt
-        with open('temp.txt', 'w') as f:
-            f.write(res)
+        res = request_gpt4v(message, history_images[0])
+        history_texts.append(res)
         return res
     except Exception as e:
         return None

 def highlight(mode, alpha, label_mode, anno_mode, *args, **kwargs):
-    # read temp.txt
-    with open('temp.txt', 'r') as f:
-        res = f.read()
-    # read temp_mask.jpg
-    mask = Image.open('temp_mask.jpg')
-    # convert mask to gray scale
-    mask = mask.convert('L')
+    res = history_texts[0]
     # find the seperate numbers in sentence res
     res = res.split(' ')
     res = [r.replace('.','').replace(',','').replace(')','').replace('"','') for r in res]
     # find all numbers in '[]'
     res = [r for r in res if '[' in r]
     res = [r.split('[')[1] for r in res]
     res = [r.split(']')[0] for r in res]
     res = [r for r in res if r.isdigit()]
     # convert res to unique
     res = list(set(res))
-    # draw mask
-    # resize image['image'] into mask size
-    # read temp.jpg
-    image = Image.open('temp.jpg')
-    image_out = image.resize(mask.size)
-    visual = Visualizer(image_out, metadata=metadata)
+    sections = []
     for i, r in enumerate(res):
-        mask_i = np.copy(np.asarray(mask))
-        mask_i[mask_i != int(r)] = False
-        mask_i[mask_i == int(r)] = True
-        demo = visual.draw_binary_mask_with_number(mask_i, text='', label_mode=label_mode, alpha=0.6, anno_mode=["Mark", "Mask"])
-        del mask_i
-    if len(res) > 0:
-        im = demo.get_image()
-        return im
-    return image_out
+        mask_i = history_masks[0][int(r)-1]['segmentation']
+        sections.append((mask_i, r))
+    return (history_images[0], sections)

 class ImageMask(gr.components.Image):
     """
@@ -205,7 +196,7 @@ def preprocess(self, x):
         slider = gr.Slider(1, 3, value=1.8, label="Granularity") # info="Choose in [1, 1.5), [1.5, 2.5), [2.5, 3] for [seem, semantic-sam (multi-level), sam]"
         mode = gr.Radio(['Automatic', 'Interactive', ], value='Automatic', label="Segmentation Mode")
         anno_mode = gr.CheckboxGroup(choices=["Mark", "Mask", "Box"], value=['Mark'], label="Annotation Mode")
-        image_out = gr.Image(label="SoM Visual Prompt",type="pil", height=512)
+        image_out = gr.AnnotatedImage(label="SoM Visual Prompt",type="pil", height=512)
         runBtn = gr.Button("Run")
         highlightBtn = gr.Button("Highlight")
         bot = gr.Chatbot(label="GPT-4V + SoM", height=256)
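Note on the flow above: the demo now keeps state in memory instead of round-tripping through temp.jpg / temp.txt / temp_mask.jpg. inference stores the annotated image and the per-mark annotation list in the module-level history lists, gpt4v_response sends history_images[0] with the request, and highlight turns the bracketed mark IDs in the GPT-4V reply into (mask, label) sections for gr.AnnotatedImage. A minimal sketch of that parsing step, with hypothetical helper names (the bracket format comes from the new metaprompt in gpt4v.py; mark labels are 1-based):

import re

def extract_mark_ids(res):
    # GPT-4V is instructed to wrap mark numbers in brackets, e.g. "the dog is [3]";
    # a regex is equivalent to the split/replace chain in highlight() above.
    return sorted(set(re.findall(r'\[(\d+)\]', res)), key=int)

def build_sections(res, anns):
    # anns is the sorted_anns list returned by the inference adapters;
    # labels are 1-based, so mark "3" maps to anns[2]['segmentation'].
    return [(anns[int(r) - 1]['segmentation'], r) for r in extract_mark_ids(res)]

gr.AnnotatedImage then renders the pair (base_image, [(mask, label), ...]), which is exactly the shape highlight returns.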
1 change: 0 additions & 1 deletion demo_som.py
@@ -62,7 +62,6 @@

 @torch.no_grad()
 def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs):
-
     if slider < 1.5:
         model_name = 'seem'
     elif slider > 2.5:
38 changes: 22 additions & 16 deletions gpt4v.py
@@ -1,6 +1,7 @@
 import os
 import base64
 import requests
+from io import BytesIO

 # Get OpenAI API Key from environment variable
 api_key = os.environ["OPENAI_API_KEY"]
@@ -10,31 +11,36 @@
 }

 metaprompt = '''
-- You always generate the answer in markdown format. For any marks mentioned in your answer, please highlight them in a red color and bold font.
+- For any marks mentioned in your answer, please highlight them with [].
 '''

 # Function to encode the image
-def encode_image(image_path):
+def encode_image_from_file(image_path):
     with open(image_path, "rb") as image_file:
         return base64.b64encode(image_file.read()).decode('utf-8')

-def prepare_inputs(message):
+def encode_image_from_pil(image):
+    buffered = BytesIO()
+    image.save(buffered, format="JPEG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')

-    # Path to your image
-    image_path = "temp.jpg"
+def prepare_inputs(message, image):

-    # Getting the base64 string
-    base64_image = encode_image(image_path)
+    # # Path to your image
+    # image_path = "temp.jpg"
+    # # Getting the base64 string
+    # base64_image = encode_image(image_path)
+    base64_image = encode_image_from_pil(image)

     payload = {
         "model": "gpt-4-vision-preview",
         "messages": [
-        # {
-        #     "role": "system",
-        #     "content": [
-        #         metaprompt
-        #     ]
-        # },
+        {
+            "role": "system",
+            "content": [
+                metaprompt
+            ]
+        },
         {
             "role": "user",
             "content": [
@@ -51,13 +57,13 @@ def prepare_inputs(message):
                 ]
             }
         ],
-        "max_tokens": 300
+        "max_tokens": 800
     }

     return payload

-def request_gpt4v(message):
-    payload = prepare_inputs(message)
+def request_gpt4v(message, image):
+    payload = prepare_inputs(message, image)
     response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
     res = response.json()['choices'][0]['message']['content']
     return res
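Usage sketch for the reworked gpt4v.py API (the example path is hypothetical): the SoM-annotated image now travels in memory as a base64-encoded JPEG instead of through temp.jpg on disk. One caveat: image.save(buffered, format="JPEG") rejects RGBA input, so convert to RGB first if the image may carry an alpha channel.

from PIL import Image

som_image = Image.open("examples/som_marked.jpg").convert("RGB")  # hypothetical path
answer = request_gpt4v("Which mark covers the dog?", som_image)
print(answer)  # marks come back bracketed, e.g. "The dog is [3]", per the metaprompt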
8 changes: 2 additions & 6 deletions task_adapter/sam/tasks/inference_sam_m2m_auto.py
@@ -51,17 +51,13 @@ def inference_sam_m2m_auto(model, image, text_size, label_mode='1', alpha=0.1, a
             # assign the mask to the mask_map
             mask_map[mask == 1] = label
             label += 1
+    im = demo.get_image()
-    # save the mask_map
-    mask_map = Image.fromarray(mask_map)
-    mask_map.save('temp_mask.jpg')
-
-    im = demo.get_image()
     # fig=plt.figure(figsize=(10, 10))
     # plt.imshow(image_ori)
     # show_anns(outputs)
     # fig.canvas.draw()
     # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-    return im
+    return im, sorted_anns


 def remove_small_regions(
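Each inference adapter in this commit now returns the annotation list alongside the rendered image, so callers can index individual masks in memory instead of re-reading temp_mask.jpg. A consumption sketch, assuming the entries follow SAM's automatic-mask-generator format (dicts with a boolean 'segmentation' array plus 'area', 'bbox', ...), which is the shape demo_gpt4v_som.py's highlight relies on:

# im is the annotated image, sorted_anns the per-mark annotation dicts
im, sorted_anns = inference_sam_m2m_auto(model, image, text_size, label_mode, alpha, anno_mode)
mask_3 = sorted_anns[2]['segmentation']  # mark labels are 1-based, so mark "3" is index 2

The three sibling adapters (interactive SAM, SEEM, Semantic-SAM) change identically below.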
8 changes: 2 additions & 6 deletions task_adapter/sam/tasks/inference_sam_m2m_interactive.py
@@ -169,17 +169,13 @@ def inference_sam_m2m_interactive(model, image, spatial_masks, text_size, label_
             # assign the mask to the mask_map
             mask_map[mask == 1] = label
             label += 1
+    im = demo.get_image()
-    # save the mask_map
-    mask_map = Image.fromarray(mask_map)
-    mask_map.save('temp_mask.jpg')
-
-    im = demo.get_image()
     # fig=plt.figure(figsize=(10, 10))
     # plt.imshow(image_ori)
     # show_anns(outputs)
     # fig.canvas.draw()
     # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-    return im
+    return im, sorted_anns


 def remove_small_regions(
6 changes: 1 addition & 5 deletions task_adapter/seem/tasks/inference_seem_interactive.py
@@ -118,16 +118,12 @@ def inference_seem_interactive(model, image, spatial_masks, text_size, label_mod
             mask_map[mask == 1] = label
             label += 1
     im = demo.get_image()
-    # save the mask_map
-    mask_map = Image.fromarray(mask_map)
-    mask_map.save('temp_mask.jpg')
-
     # fig=plt.figure(figsize=(10, 10))
     # plt.imshow(image_ori)
     # show_anns(outputs)
     # fig.canvas.draw()
     # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-    return im
+    return im, sorted_anns


 def remove_small_regions(
6 changes: 1 addition & 5 deletions task_adapter/seem/tasks/inference_seem_pano.py
@@ -113,16 +113,12 @@ def inference_seem_pano(model, image, text_size, label_mode='1', alpha=0.1, anno
             mask_map[mask == 1] = label
             label += 1
     im = demo.get_image()
-    # save the mask_map
-    mask_map = Image.fromarray(mask_map)
-    mask_map.save('temp_mask.jpg')
-
     # fig=plt.figure(figsize=(10, 10))
     # plt.imshow(image_ori)
     # show_anns(outputs)
     # fig.canvas.draw()
     # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-    return im
+    return im, sorted_anns


 def remove_small_regions(
8 changes: 2 additions & 6 deletions task_adapter/semantic_sam/tasks/inference_semsam_m2m_auto.py
@@ -56,17 +56,13 @@ def inference_semsam_m2m_auto(model, image, level, all_classes, all_parts, thres
             # assign the mask to the mask_map
             mask_map[mask == 1] = label
             label += 1
+    im = demo.get_image()
-    # save the mask_map
-    mask_map = Image.fromarray(mask_map)
-    mask_map.save('temp_mask.jpg')
-
-    im = demo.get_image()
     # fig=plt.figure(figsize=(10, 10))
     # plt.imshow(image_ori)
     # show_anns(outputs)
     # fig.canvas.draw()
     # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-    return im
+    return im, sorted_anns


 def remove_small_regions(
2 changes: 1 addition & 1 deletion task_adapter/utils/visualizer.py
@@ -1173,7 +1173,7 @@ def draw_binary_mask_with_number(

         if text is not None and has_valid_segment:
             # lighter_color = tuple([x*0.2 for x in color])
-            lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
+            lighter_color = [1,1,1] # self._change_color_brightness(color, brightness_factor=0.7)
             self._draw_number_in_mask(binary_mask, text, lighter_color, label_mode)
         return self.output

