diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py
index 99fa49b..018cb18 100644
--- a/src/nrtk_explorer/app/transforms.py
+++ b/src/nrtk_explorer/app/transforms.py
@@ -337,8 +337,19 @@ def compute_predictions_original_images(self, dataset_ids):
         ground_truth_annotations = self.ground_truth_annotations.get_annotations(
             dataset_ids
         ).values()
+
+        def ensure_bbox(annotation, image):
+            if "bbox" not in annotation:
+                annotation["bbox"] = [0, 0, image.width, image.height]
+            return annotation
+
+        ground_truth_with_bbox = [
+            [ensure_bbox(annotation, image) for annotation in annotations]
+            for annotations, image in zip(ground_truth_annotations, image_id_to_image.values())
+        ]
+
         ground_truth_predictions = convert_from_ground_truth_to_second_arg(
-            ground_truth_annotations, self.context.dataset
+            ground_truth_with_bbox, self.context.dataset
         )
         scores = compute_score(
             dataset_ids,
diff --git a/src/nrtk_explorer/library/dataset.py b/src/nrtk_explorer/library/dataset.py
index 94d7b1b..72ee843 100644
--- a/src/nrtk_explorer/library/dataset.py
+++ b/src/nrtk_explorer/library/dataset.py
@@ -90,24 +90,32 @@ def expand_hugging_face_datasets(dataset_identifiers: SequenceType[str], streami
     return expanded_identifiers
 
 
+def find_column_name(features, column_names):
+    return next((key for key in column_names if key in features), None)
+
+
 class HuggingFaceDataset(BaseDataset):
     """Interface for Hugging Face datasets with a similar API to JsonDataset."""
 
     def __init__(self, identifier: str):
+        self.imgs: dict[str, dict] = {}
+        self.anns: dict[str, dict] = {}
+        self.cats: dict[str, dict] = {}
+        self._id_to_row_idx: dict[str, int] = {}
+
         repo, config, split, streaming = identifier.split("@")
         self._streaming = streaming == "streaming"
         self._dataset = load_dataset(repo, config, split=split, streaming=self._streaming)
-        # transforms and base64 encoding require RGB mode
-        self._dataset.cast_column("image", DatasetImage(mode="RGB"))
         if self._streaming:
             self._dataset = self._dataset.take(HF_ROWS_TO_TAKE_STREAMING)
-        self.imgs: dict[str, dict] = {}
-        self.anns: dict[str, dict] = {}
-        self.cats: dict[str, dict] = {}
-        self._id_to_row_idx: dict[str, int] = {}
         self._load_data()
 
     def _load_data(self):
+        image_key = find_column_name(self._dataset.features, ["image", "img"])
+        self._image_key = image_key
+        # transforms and base64 encoding require RGB mode
+        self._dataset = self._dataset.cast_column(image_key, DatasetImage(mode="RGB"))
+
         counter = 0
 
         def make_id():
@@ -136,21 +144,32 @@ def extract_labels(feature):
         if labels:
             self.cats = {i: {"id": i, "name": str(name)} for i, name in enumerate(labels)}
 
+        objects_key = find_column_name(self._dataset.features, ["objects"])
+
+        classifications_key = find_column_name(
+            self._dataset.features,
+            [
+                "labels",
+                "label",
+                "classifications",
+            ],
+        )
+
         new_cats = set()
         # speed initial metadata process by not loading images if we can random access rows (not streaming)
         maybe_no_image = (
-            self._dataset if self._streaming else self._dataset.remove_columns(["image"])
+            self._dataset if self._streaming else self._dataset.remove_columns([image_key])
         )
 
         for idx, example in enumerate(maybe_no_image):
             id = example.get("id", example.get("image_id", idx))
             if self._streaming:
-                self.imgs[id] = {"id": id, "image": example["image"]}
+                self.imgs[id] = {"id": id, "image": example[image_key]}
             else:
                 self.imgs[id] = {"id": id}
             self._id_to_row_idx[id] = idx
-            if "objects" in example:
-                objects = example["objects"]
+            if objects_key:
+                objects = example[objects_key]
                 if isinstance(objects, list):
                     # Convert list of dicts to dict of lists. We want columns, not rows.
                     cat_keys = ["category", "category_id", "label"]
@@ -175,6 +194,20 @@ def extract_labels(feature):
                         "bbox": bbox,
                     }
 
+            if classifications_key:
+                classes = example[classifications_key]
+                if not isinstance(classes, list):
+                    classes = [classes]
+                for cat_id in classes:
+                    if cat_id not in self.cats:
+                        new_cats.add(cat_id)
+                    ann_id = make_id()
+                    self.anns[ann_id] = {
+                        "id": ann_id,
+                        "image_id": id,
+                        "category_id": cat_id,
+                    }
+
         if new_cats:
             max_existing_id = max(self.cats.keys(), default=0)
             for new_cat in new_cats:
@@ -189,7 +222,7 @@ def get_image(self, id):
             return self.imgs[id]["image"]
         else:
             row_idx = self._id_to_row_idx[id]
-            return self._dataset[row_idx]["image"]
+            return self._dataset[row_idx][self._image_key]
 
     @lru_cache
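
A minimal standalone sketch of how the two helpers added above behave. `find_column_name` and `ensure_bbox` are copied from the diff; the `FakeImage` class and the sample feature dicts are hypothetical stand-ins for PIL images and Hugging Face `features` mappings:

```python
# Sketch only: FakeImage and the sample dicts below are hypothetical;
# find_column_name and ensure_bbox mirror the helpers in this diff.


def find_column_name(features, column_names):
    # First candidate column present in the dataset's features, else None.
    return next((key for key in column_names if key in features), None)


def ensure_bbox(annotation, image):
    # Classification-only ground truth carries no "bbox"; scoring expects
    # one, so fall back to a box covering the whole image.
    if "bbox" not in annotation:
        annotation["bbox"] = [0, 0, image.width, image.height]
    return annotation


class FakeImage:
    width, height = 640, 480


# Column detection tolerates datasets that name the image column "img".
assert find_column_name({"img": None, "label": None}, ["image", "img"]) == "img"
assert find_column_name({"img": None}, ["objects"]) is None

# A classification annotation gains a full-image bounding box.
ann = ensure_bbox({"id": 1, "image_id": 7, "category_id": 2}, FakeImage())
assert ann["bbox"] == [0, 0, 640, 480]
```

The full-image fallback is what lets `compute_predictions_original_images` score classification-only datasets with the same bbox-based pipeline used for detection datasets.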