Skip to content

Commit

Permalink
Fix features!
Browse files Browse the repository at this point in the history
  • Loading branch information
haiwei-luo committed Jul 15, 2024
1 parent eb5c904 commit 9ec5045
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 16 deletions.
2 changes: 1 addition & 1 deletion common/vision/lasr_vision_clip/nodes/vqa
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class VqaService:
"""
possible_answers = request.possible_answers
answer, cos_score, annotated_img = query_image(
self._model, request.image_raw, possible_answers, annotate=True
request.image_raw, self._model, possible_answers, annotate=True
)

self._debug_pub.publish(annotated_img)
Expand Down
4 changes: 1 addition & 3 deletions common/vision/lasr_vision_feature_extraction/nodes/service
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ def detect(
head_frame,
torso_frame,
full_frame,
head_mask,
torso_mask,
request.image_raw,
image_raw=request.image_raw,
cloth_predictor=cloth_predictor,
)
response = TorchFaceFeatureDetectionDescriptionResponse()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,6 @@ def predict_frame(
head_frame,
torso_frame,
full_frame,
head_mask,
torso_mask,
cloth_predictor,
image_raw,
clip_service: rospy.ServiceProxy = rospy.ServiceProxy(
Expand All @@ -571,6 +569,7 @@ def predict_frame(
rst_person = {
"glasses": -0.5,
"hat": -0.5,
"hair_shape": "short hair"
}

glasses_query = VqaRequest(
Expand All @@ -583,14 +582,24 @@ def predict_frame(
image_raw=image_raw,
)

hair_query = VqaRequest(
possible_answers=["A person with short hair", "A person with long hair"],
image_raw=image_raw,
)

glasses_response = clip_service(glasses_query)
hat_response = clip_service(hat_query)

if glasses_response.answer == "A person wearing glasses":
rst_person["glasses"] = 0.5
hat_response = clip_service(hat_query)
if hat_response.answer == "A person wearing a hat":
rst_person["hat"] = 0.5

hair_response = clip_service(hair_query)

if hair_response.answer == "A person with long hair":
rst_person["hair_shape"] = "long hair"

result = {
**rst_person,
**rst_cloth,
Expand Down
3 changes: 1 addition & 2 deletions tasks/receptionist/launch/setup.launch
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
<node pkg="lasr_vision_feature_extraction" type="service" name="torch_service" output="screen"/>
<node pkg="lasr_vision_clip" name="clip_service" type="vqa" output="screen"/>
<arg name="debug" default="true" />
<node pkg="lasr_vision_bodypix" type="mask_service.py" name="bodypix_service" output="screen" args="--debug $(arg debug)"/>
<node pkg="lasr_vision_bodypix" type="keypoint_service.py" name="bodypix__keypoint_service" output="screen" args="--debug $(arg debug)"/>
<node pkg="lasr_vision_bodypix" type="bodypix_services.py" name="bodypix_service" output="screen" args="--debug $(arg debug)"/>
<node pkg="lasr_vision_deepface" type="service" name="deepface_service" output="screen"/>

</launch>
17 changes: 10 additions & 7 deletions tasks/receptionist/src/receptionist/states/introduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,20 @@ def stringify_guest_data(
},
)


relevant_guest_data["attributes"]["has_hair"] = 0.5

guest_str = f"{relevant_guest_data['name']}, their favourite drink is {relevant_guest_data['drink']}. "

if not relevant_guest_data["detection"] or not describe_features:
return guest_str

filtered_attributes = {}
# filtered_attributes["hair"] = {
# "confidence": relevant_guest_data["attributes"]["has_hair"],
# "hair_shape": relevant_guest_data["attributes"]["hair_shape"],
# "hair_colour": relevant_guest_data["attributes"]["hair_colour"],
# }
filtered_attributes["hair"] = {
"confidence": relevant_guest_data["attributes"]["has_hair"],
"hair_shape": relevant_guest_data["attributes"]["hair_shape"],
# "hair_colour": relevant_guest_data["attributes"]["hair_colour"],
}

most_confident_clothes = find_most_confident_clothes(
relevant_guest_data,
Expand Down Expand Up @@ -117,8 +120,8 @@ def stringify_guest_data(

if attribute_name == "hair":
hair_shape = attribute_value["hair_shape"]
hair_colour = attribute_value["hair_colour"]
guest_str += f"They have {hair_shape} and {hair_colour}. "
# hair_colour = attribute_value["hair_colour"]
guest_str += f"They have {hair_shape}. "
elif attribute_name == "facial_hair":
if confidence < 0:
guest_str += "They don't have facial hair. "
Expand Down

0 comments on commit 9ec5045

Please sign in to comment.