Skip to content

Commit

Permalink
Updated feature extractors and corrected the labels and threshold mis…
Browse files Browse the repository at this point in the history
…match error. (LASR-at-Home#222)

Co-authored-by: Benteng Ma <[email protected]>
Co-authored-by: Jared Swift <[email protected]>
  • Loading branch information
3 people authored Jun 21, 2024
1 parent 3dcff75 commit 842f631
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 74 deletions.
3 changes: 1 addition & 2 deletions common/vision/lasr_vision_feature_extraction/nodes/service
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ if __name__ == '__main__':
head_model = lasr_vision_feature_extraction.load_face_classifier_model()
head_predictor = lasr_vision_feature_extraction.Predictor(head_model, torch.device('cpu'), CelebAMaskHQCategoriesAndAttributes)
cloth_model = lasr_vision_feature_extraction.load_cloth_classifier_model()
cloth_model.return_bbox = False # unify returns
cloth_predictor = lasr_vision_feature_extraction.Predictor(cloth_model, torch.device('cpu'), DeepFashion2GeneralizedCategoriesAndAttributes)
cloth_predictor = lasr_vision_feature_extraction.ClothPredictor(cloth_model, torch.device('cpu'), DeepFashion2GeneralizedCategoriesAndAttributes)
rospy.init_node('torch_service')
rospy.Service('/torch/detect/face_features', TorchFaceFeatureDetectionDescription, detect)
rospy.loginfo('Torch service started')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,20 +346,16 @@ def __init__(

self._thresholds_mask: list[float] = []
self._thresholds_pred: list[float] = []
for key in sorted(
list(self.categories_and_attributes.merged_categories.keys())
):
for key in sorted(list(self.categories_and_attributes.thresholds_mask.keys())):
self._thresholds_mask.append(
self.categories_and_attributes.thresholds_mask[key]
)
for attribute in self.categories_and_attributes.attributes:
if attribute not in self.categories_and_attributes.avoided_attributes:
self._thresholds_pred.append(
self.categories_and_attributes.thresholds_pred[attribute]
)
for key in sorted(list(self.categories_and_attributes.thresholds_pred.keys())):
self._thresholds_pred.append(
self.categories_and_attributes.thresholds_pred[key]
)

def predict(self, rgb_image: np.ndarray) -> ImageWithMasksAndAttributes:
mean_val = np.mean(rgb_image)
image_tensor = (
torch.from_numpy(rgb_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
)
Expand Down Expand Up @@ -391,6 +387,62 @@ def predict(self, rgb_image: np.ndarray) -> ImageWithMasksAndAttributes:
return image_obj


class ClothPredictor(Predictor):
def predict(self, rgb_image: np.ndarray) -> ImageWithMasksAndAttributes:
general_categories = [
"top",
"down",
"outwear",
"dress",
]
categories = [
"top",
"down",
"outwear",
"dress",
"short sleeve top",
"long sleeve top",
"short sleeve outwear",
"long sleeve outwear",
"vest",
"sling",
"shorts",
"trousers",
"skirt",
"short sleeve dress",
"long sleeve dress",
"vest dress",
"sling dress",
]
image_tensor = (
torch.from_numpy(rgb_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
)
pred_masks, pred_classes, pred_bboxes = self.model(image_tensor)
# Apply binary erosion and dilation to the masks
pred_masks = binary_erosion_dilation(
pred_masks,
thresholds=self._thresholds_pred,
erosion_iterations=1,
dilation_iterations=1,
)
pred_masks = pred_masks.detach().squeeze(0).numpy().astype(np.uint8)
mask_list = [pred_masks[i, :, :] for i in range(pred_masks.shape[0])]
pred_classes = pred_classes.detach().squeeze(0).numpy()
class_list = [pred_classes[i].item() for i in range(pred_classes.shape[0])]
mask_dict = {}
for i, mask in enumerate(mask_list):
mask_dict[categories[i]] = mask
attribute_dict = {}
class_list_iter = class_list.__iter__()
for attribute in categories:
# if attribute not in self.categories_and_attributes.avoided_attributes:
attribute_dict[attribute] = class_list_iter.__next__()
image_obj = ImageOfCloth(
rgb_image, mask_dict, attribute_dict, self.categories_and_attributes
)
return image_obj


def load_face_classifier_model():
cat_layers = CelebAMaskHQCategoriesAndAttributes.merged_categories.keys().__len__()
segment_model = UNetWithResnetEncoder(num_classes=cat_layers)
Expand Down Expand Up @@ -548,9 +600,9 @@ def binary_erosion_dilation(

# Check if the length of thresholds matches the number of channels
if len(thresholds) != tensor.size(1):
# the error should be here, just removed for now since there's some other bug I haven't fixed.
# raise ValueError(f"Length of thresholds {len(thresholds)} must match the number of channels {tensor.size(1)}")
thresholds = [0.5 for _ in range(tensor.size(1))]
raise ValueError(
f"Length of thresholds {len(thresholds)} must match the number of channels {tensor.size(1)}"
)

# Binary thresholding
for i, threshold in enumerate(thresholds):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,22 +159,19 @@ class CelebAMaskHQCategoriesAndAttributes(CategoriesAndAttributes):
for key in sorted(merged_categories.keys()):
thresholds_mask[key] = 0.5
for key in attributes + mask_labels:
thresholds_pred[key] = 0.5
if key not in avoided_attributes:
thresholds_pred[key] = 0.5

# set specific thresholds:
thresholds_mask["eye_g"] = 0.25
thresholds_pred["Eyeglasses"] = 0.25
thresholds_mask["eye_g"] = 0.5
thresholds_pred["Eyeglasses"] = 0.5
thresholds_pred["Wearing_Earrings"] = 0.5
thresholds_pred["Wearing_Necklace"] = 0.5
thresholds_pred["Wearing_Necktie"] = 0.5


class DeepFashion2GeneralizedCategoriesAndAttributes(CategoriesAndAttributes):
mask_categories = [
"top",
"down",
"outwear",
"dress",
"short sleeve top",
"long sleeve top",
"short sleeve outwear",
Expand Down Expand Up @@ -212,19 +209,19 @@ class DeepFashion2GeneralizedCategoriesAndAttributes(CategoriesAndAttributes):
"sling dress",
],
}
mask_labels = [
"top",
"down",
"outwear",
"dress",
]
_categories_to_merge = []
for key in sorted(list(merged_categories.keys())):
for cat in merged_categories[key]:
_categories_to_merge.append(cat)
for key in mask_categories:
if key not in _categories_to_merge:
merged_categories[key] = [key]
mask_labels = [
"top",
"down",
"outwear",
"dress",
]
selective_attributes = {}
plane_attributes = []
avoided_attributes = []
Expand All @@ -250,7 +247,6 @@ class DeepFashion2GeneralizedCategoriesAndAttributes(CategoriesAndAttributes):
# set default thresholds:
for key in sorted(merged_categories.keys()):
thresholds_mask[key] = 0.5
for key in sorted(mask_categories):
thresholds_mask[key] = 0.5
for key in attributes + mask_labels:
thresholds_pred[key] = 0.5
pass
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def from_parent_instance(
categories_and_attributes=parent_instance.categories_and_attributes,
)

def describe(self) -> str:
def describe(self) -> dict:
male = (
self.attributes["Male"]
> self.categories_and_attributes.thresholds_pred["Male"],
Expand Down Expand Up @@ -193,7 +193,6 @@ def describe(self) -> str:
"has_hair": has_hair[0],
"hair_colour": hair_colour_str,
"hair_shape": hair_shape_str,
"male": male[0],
"facial_hair": facial_hair[0] != "No_Beard",
"hat": hat[0],
"glasses": glasses[0],
Expand Down Expand Up @@ -232,50 +231,79 @@ def from_parent_instance(
categories_and_attributes=parent_instance.categories_and_attributes,
)

def describe(self) -> str:
def describe(self) -> dict:
top = (
self.attributes["top"]
> self.categories_and_attributes.thresholds_pred["top"]
)
down = (
self.attributes["down"]
> self.categories_and_attributes.thresholds_pred["down"]
)
dress = (
self.attributes["dress"]
> self.categories_and_attributes.thresholds_pred["dress"]
)
outwear = (
self.attributes["outwear"]
> self.categories_and_attributes.thresholds_pred["outwear"]
)

result = {
# not in a loop for now, likely to add more logic combined with a classifier of more specific cloth classes.
"attributes": {
"top": self.attributes["top"]
> self.categories_and_attributes.thresholds_pred["top"],
"down": self.attributes["down"]
> self.categories_and_attributes.thresholds_pred["down"],
"outwear": self.attributes["outwear"]
> self.categories_and_attributes.thresholds_pred["outwear"],
"dress": self.attributes["dress"]
> self.categories_and_attributes.thresholds_pred["dress"],
"short sleeve top": self.attributes["short sleeve top"]
> self.categories_and_attributes.thresholds_pred["short sleeve top"],
"long sleeve top": self.attributes["long sleeve top"]
> self.categories_and_attributes.thresholds_pred["long sleeve top"],
"short sleeve outwear": self.attributes["short sleeve outwear"]
> self.categories_and_attributes.thresholds_pred[
"short sleeve outwear"
],
"long sleeve outwear": self.attributes["long sleeve outwear"]
> self.categories_and_attributes.thresholds_pred["long sleeve outwear"],
"vest": self.attributes["vest"]
> self.categories_and_attributes.thresholds_pred["vest"],
"sling": self.attributes["sling"]
> self.categories_and_attributes.thresholds_pred["sling"],
"outwear": self.attributes["outwear"]
> self.categories_and_attributes.thresholds_pred["outwear"],
"shorts": self.attributes["shorts"]
> self.categories_and_attributes.thresholds_pred["shorts"],
"trousers": self.attributes["trousers"]
> self.categories_and_attributes.thresholds_pred["trousers"],
"skirt": self.attributes["skirt"]
> self.categories_and_attributes.thresholds_pred["skirt"],
"short sleeve dress": self.attributes["short sleeve dress"]
> self.categories_and_attributes.thresholds_pred["short sleeve dress"],
"long sleeve dress": self.attributes["long sleeve dress"]
> self.categories_and_attributes.thresholds_pred["long sleeve dress"],
"vest dress": self.attributes["vest dress"]
> self.categories_and_attributes.thresholds_pred["vest dress"],
"sling dress": self.attributes["sling dress"]
> self.categories_and_attributes.thresholds_pred["sling dress"],
},
"attributes": {},
"description": "this descrcription will be completed if we find out it is better to do it here.",
}

for attribute in [
"short sleeve top",
"long sleeve top",
"short sleeve outwear",
"long sleeve outwear",
"short sleeve dress",
"long sleeve dress",
"vest dress",
"sling dress",
"sleeveless top",
]:
result["attributes"][attribute] = False

if top:
max_prob = 0.0
max_attribute = "short sleeve top"
for attribute in ["short sleeve top", "long sleeve top", "vest", "sling"]:
if self.attributes[attribute] > max_prob:
max_prob = self.attributes[attribute]
max_attribute = attribute
if max_attribute in ["vest", "sling"]:
max_attribute = "sleeveless top"
result["attributes"][max_attribute] = True


if outwear:
max_prob = 0.0
max_attribute = "short sleeve outwear"
for attribute in [
"short sleeve outwear",
"long sleeve outwear",
]:
if self.attributes[attribute] > max_prob:
max_prob = self.attributes[attribute]
max_attribute = attribute
result["attributes"][max_attribute] = True

if dress:
max_prob = 0.0
max_attribute = "short sleeve dress"
for attribute in [
"short sleeve dress",
"long sleeve dress",
"vest dress",
"sling dress",
]:
if self.attributes[attribute] > max_prob:
max_prob = self.attributes[attribute]
max_attribute = attribute
result["attributes"][max_attribute] = True

return result
15 changes: 13 additions & 2 deletions skills/src/lasr_skills/describe_people.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3

import os
import cv2
import rospy
import smach
Expand All @@ -26,9 +26,20 @@ def __init__(self):
)

with self:
# conditional topic and crop method for flexibility
rgb_topic = (
"/xtion/rgb/image_raw"
if "tiago" in os.environ["ROS_MASTER_URI"]
else "/camera/image_raw"
)
crop_method = (
"closest" if "tiago" in os.environ["ROS_MASTER_URI"] else "centered"
)
smach.StateMachine.add(
"GET_IMAGE",
GetCroppedImage(object_name="person", crop_method="closest"),
GetCroppedImage(
object_name="person", crop_method=crop_method, rgb_topic=rgb_topic
),
transitions={"succeeded": "CONVERT_IMAGE"},
)
smach.StateMachine.add(
Expand Down

0 comments on commit 842f631

Please sign in to comment.