diff --git a/.gitignore b/.gitignore
index efeb31e..68061c3 100755
--- a/.gitignore
+++ b/.gitignore
@@ -158,4 +158,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
\ No newline at end of file
+#.idea/
+
+# mac system
+*.DS_Store
diff --git a/README.md b/README.md
index 367e2e1..86fa20f 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@
The repo is the toolbox for D3
[Doc 📚]
- [HuggingFace 🤗]
+
[Paper (DOD) 📄]
[Paper (GRES) 📄]
[Awesome-DOD 🕶️]
@@ -20,25 +20,32 @@
Description Detection Dataset ($D^3$, /dikju:b/) is an attempt at creating a next-generation object detection dataset. Unlike traditional detection datasets, the class names of the objects are no longer simple nouns or noun phrases, but rather complex and descriptive, such as `a dog not being held by a leash`. For each image in the dataset, any object that matches the description is annotated. The dataset provides annotations such as bounding boxes and finely crafted instance masks. We believe it will contribute to computer vision and vision-language communities.
+
# News
-- [10/12/2023] We released an [awesome-described-object-detection](https://github.com/Charles-Xie/awesome-described-object-detection) list to collect and track related works. The paper is renamed as *Described Object Detection: Liberating Object Detection with Flexible Expressions* ([arxiv](https://arxiv.org/abs/2307.12813)).
+- [02/14/2024] Evaluation on several SOTA methods (SPHNX (the first MLLM evaluated!), G-DINO, UNINEXT, etc.) are released, together with a [leaderboard](https://github.com/shikras/d-cube/tree/main/eval_sota) for $D^3$. :fire::fire:
+
+- [10/12/2023] We released an [awesome-described-object-detection](https://github.com/Charles-Xie/awesome-described-object-detection) list to collect and track related works.
- [09/22/2023] Our DOD [paper](https://arxiv.org/abs/2307.12813) just got accepted by NeurIPS 2023! :fire:
- [07/25/2023] This toolkit is available on PyPI now. You can install this repo with `pip install ddd-dataset`.
-- [07/25/2023] The [paper preprint](https://arxiv.org/abs/2307.12813) of *Exposing the Troublemakers in Described Object Detection*, introducing the DOD task and the $D^3$ dataset, is available on arxiv. Check it out!
+- [07/25/2023] The [paper preprint](https://arxiv.org/abs/2307.12813) introducing the DOD task and the $D^3$ dataset, is available on arxiv. Check it out!
- [07/18/2023] We have released our Description Detection Dataset ($D^3$) and the first version of $D^3$ toolbox. You can download it now for your project.
- [07/14/2023] Our GRES [paper](https://arxiv.org/abs/2305.12452) has been accepted by ICCV 2023.
+
+
# Contents
- [Dataset Highlight](#task-and-dataset-highlight)
- [Download](#download)
- [Installation](#installation)
- [Usage](#usage)
+
+
# Task and Dataset Highlight
The $D^3$ dataset is meant for the Described Object Detection (DOD) task. In the image below we show the difference between Referring Expression Comprehension (REC), Object Detection/Open-Vocabulary Detection (OVD) and Described Object Detection (DOD). OVD detect object based on category name, and each category can have zero to multiple instances; REC grounds one region based on a language description, whether the object truly exits or not; DOD detect all instances on each image in the dataset, based on a flexible reference. Related works are tracked in the [awesome-DOD](https://github.com/Charles-Xie/awesome-described-object-detection) list.
@@ -47,17 +54,24 @@ The $D^3$ dataset is meant for the Described Object Detection (DOD) task. In the
For more information on the characteristics of this dataset, please refer to our paper.
+
+
# Download
Currently we host the $D^3$ dataset on cloud drives. You can download the dataset from [Google Drive](https://drive.google.com/drive/folders/11kfY12NzKPwsliLEcIYki1yUqt7PbMEi?usp=sharing) or [Baidu Pan]().
After downloading the `d3_images.zip` (images in the dataset), `d3_pkl.zip` (dataset information for this toolkit) and `d3_json.zip` (annotation for evaluation), please extract these 3 zip files to your custom `IMG_ROOT`, `PKL_PATH` and `JSON_ANNO_PATH` directory. These paths will be used when you perform inference or evaluation on this dataset.
+
+
# Installation
## Prerequisites
This toolkit requires a few python packages like `numpy` and `pycocotools`. Other packages like `matplotlib` and `opencv-python` may also be required if you want to utilize the visualization scripts.
-There are three ways to install $D^3$ toolbox, and the third one (with huggingface) is currently in the works and will be available soon.
+
+
+There are multiple ways to install $D^3$ toolbox, as listed below:
+
## Install with pip
```bash
@@ -75,10 +89,12 @@ python -m pip install .
# option 2: just put the d-cube/d_cube directory in the root directory of your local repository
```
-## Via HuggingFace Datasets 🤗
+
+
+
# Usage
Please refer to the [documentation 📚](doc.md) for more details.
@@ -93,8 +109,12 @@ all_img_info = d3.load_imgs(all_img_ids) # load images by passing a list of som
img_path = all_img_info[0]["file_name"] # obtain one image path so you can load it and inference
```
+Some frequently asked questions are answered in [this Q&A file](./qa.md).
+
# Citation
+
If you use our $D^3$ dataset, this toolbox, or otherwise find our work valuable, please cite [our paper](https://arxiv.org/abs/2307.12813):
+
```bibtex
@inproceedings{xie2023DOD,
title={Described Object Detection: Liberating Object Detection with Flexible Expressions},
@@ -111,4 +131,4 @@ If you use our $D^3$ dataset, this toolbox, or otherwise find our work valuable,
}
```
-More works related to Described Object Detection are tracked in this list: [awesome-described-object-detection](https://github.com/Charles-Xie/awesome-described-object-detection).
\ No newline at end of file
+More works related to Described Object Detection are tracked in this list: [awesome-described-object-detection](https://github.com/Charles-Xie/awesome-described-object-detection).
diff --git a/d_cube/d3.py b/d_cube/d3.py
index 5c0cba8..480c3d0 100755
--- a/d_cube/d3.py
+++ b/d_cube/d3.py
@@ -513,7 +513,7 @@ def stat_description(self, with_rev=False, inter_group=False):
num_img_sent += len(cur_sent_set)
stat_dict["num_img_sent"] = num_img_sent
- # Number of anti img-sent pair
+ # Number of absence img-sent pair
num_anti_img_sent = 0
for img_id in self.data["images"].keys():
anno_ids = self.get_anno_ids(img_ids=img_id)
diff --git a/doc.md b/doc.md
index eaeb3c6..bcee034 100644
--- a/doc.md
+++ b/doc.md
@@ -1,12 +1,17 @@
# $D^3$ Toolkit Documentation
+
## Table of Contents
- [Inference](#inference-on-d3)
- [Key Concepts](#key-concepts-for-users)
-- [Evaluation](#evaluation)
+- [Evaluation Settings](#evaluation-settings)
+- [Evaluation Code and Examples](#evaluation-code-and-examples)
- [Dataset statistics](#dataset-statistics)
+
+
+
## Inference on $D^3$
```python
@@ -22,7 +27,7 @@ img_path = all_img_info[0]["file_name"] # obtain one image path so you can load
group_ids = d3.get_group_ids(img_ids=[img_id]) # get the group ids by passing anno ids, image ids, etc.
sent_ids = d3.get_sent_ids(group_ids=group_ids) # get the sentence ids by passing image ids, group ids, etc.
sent_list = d3.load_sents(sent_ids=sent_ids)
-ref_list = [sent['raw_sent'] for sent in sent_list]
+ref_list = [sent['raw_sent'] for sent in sent_list] # list[str]
# use these language references in `ref_list` as the references to your REC/OVD/DOD model
# save the result to a JSON file
@@ -32,13 +37,15 @@ Concepts and structures of `anno`, `image`, `sent` and `group` are explained in
In [this directory](eval_sota/) we provide the inference (and evaluation) script on some existing SOTA OVD/REC methods.
+
+
### Output Format
When the inference is done, you need to save a JSON file in the format below (COCO standard output JSON form):
```json
[
{
"category_id": "int, the value of sent_id, range [1, 422]",
- "bbox": "[x1, y1, w, h], predicted by your model, same as COCO result format",
+ "bbox": "list[int], [x1, y1, w, h], predicted by your model, same as COCO result format, absolute value in the range of [w, h, w, h]",
"image_id": "int, img_id, can be 0, 1, 2, ....",
"score": "float, predicted by your model, no restriction on its absolute value range"
}
@@ -46,12 +53,10 @@ When the inference is done, you need to save a JSON file in the format below (CO
```
This JSON file should contain a list, where each item in the list is a dictionary of one detection result.
-With this JSON saved, you can evaluate the JSON in the next step. See [the evaluation step](#evaluation).
+With this JSON saved, you can evaluate the JSON in the next step. See [the evaluation step](#evaluation-code-and-examples).
-### Intra- or Inter-Group Settings
-The default evaluation protocol is the intra-group setting, where only a certain references are considerred for each image. Inter-group setting, where all references in the dataset are considerred for each image, can be easily achieved by changing `sent_ids = d3.get_sent_ids(group_ids=group_ids)` to `sent_ids = d3.get_sent_ids()`. This will use all the sentences in the dataset, rather than a few sentences in the group that this image belongs to.
## Key Concepts for Users
@@ -116,7 +121,7 @@ A Python dictionary where the keys are integers and the values are dictionaries
* `id`: an integer representing the ID of the sentence.
* `anno_id`: a list of integers representing the IDs of annotations associated with this sentence.
* `group_id`: a list of integers representing the IDs of groups associated with this sentence.
-* `is_negative`: a boolean indicating whether this sentence is anti-expression or not.
+* `is_negative`: a boolean indicating whether this sentence is *absence expression* or not. `True` means *absence expression*.
* `raw_sent`: a string representing the raw text of the sentence in English.
* `raw_sent_zh`: a string representing the raw text of the sentence in Chinese.
@@ -137,7 +142,7 @@ A Python dictionary where the keys are integers and the values are dictionaries
A Python dictionary where the keys are integers and the values are dictionaries with the following key-value pairs:
* `id`: an integer representing the ID of the group.
-* `pos_sent_id`: a list of integers representing the IDs of sentences that has referred obejct in the group.
+* `pos_sent_id`: a list of integers representing the IDs of sentences that has referred obejct in the group.
* `inner_sent_id`: a list of integers representing the IDs of sentences belonging to this group.
* `outer_sent_id`: a list of integers representing the IDs of outer-group sentences that has referred obejct in the group.
* `img_id`: a list of integers representing the IDs of images of this group.
@@ -160,9 +165,61 @@ A Python dictionary where the keys are integers and the values are dictionaries
}
```
-## Evaluation
+
+
+
+
+## Evaluation Settings
+
+
+### Intra- or Inter-Group Settings
+
+The default evaluation protocol is the intra-group setting, where only a certain references are evaluated for each image.
+
+In the $D^3$ dataset, images are collected for different groups (scenarios), and the categories (descriptions) are designed based on the scenarios. For the intra-group setting, each image are only evaluated with the descriptions from the group the image belongs to. We call this **intra-scenario setting**.
+
+Note that each category is actually annotated on each image (with positive or negative instances).
+So you can also evaluate all categories on all images, just like traditional detection datasets. We call this **inter-scenario setting**.
+This is quite challenging for the DOD task as this will produce many false positive instances on current methods.
+
+For intra-group evaluation, you should use:
+```
+sent_ids = d3.get_sent_ids(group_ids=group_ids)
+# only get the refs (sents) for the group the image belongs to, which is usually 4
+```
+
+For inter-group evaluation, change the correponding code to:
+
+```
+sent_ids = d3.get_sent_ids()
+# get all the refs in the dataset
+```
+
+This will use all the sentences in the dataset, rather than a few sentences in the group that this image belongs to.
+
+This is the only difference in the implentation and evaluation. No further code changes need to be applied.
+
+For more information, you can refer to the Section 3.4 of the DOD paper.
+
+
+### FULL, PRES and ABS
+
+FULL, PRES and ABS means the full descriptions (422 categories), presence descriptions (316 categories) and absence descriptions (106 categories).
+
+The meaning of absence descriptions are the descriptions involving the absence of some concepts, like lacking certain relationships, attributes or objects. For example, descriptions like "dog *without* leash", "person *without* helmet" and "a hat that is *not* blue" are absence ones.
+Similary, the descriptions involving *only* the presence of some concepts are presence descriptions.
+
+Most existing REC datasets have presence descriptions but few absence descriptions.
+
+For more details and the meaning of evaluating absence descriptions, please refer to Section 3.1 of the DOD paper.
+
+
+
+
+## Evaluation Code and Examples
In this part, we introduce how to evaluate the performance and get the metric values given the prediction result of a JSON file.
+
### Write a Snippet in Your Code
This is based on [cocoapi (pycocotools)](https://github.com/cocodataset/cocoapi/tree/master/PythonAPI), and is quite simple:
@@ -204,10 +261,15 @@ optional arguments:
--xyxy2xywh transform box coords from xyxy to xywh
```
-## Evaluation Examples on SOTA Methods
-See [this directory](eval_sota/) for details. More scripts for evaluating popular SOTA OVD/REC/other methods on $D^3$ will be added later.
+### Evaluation Examples on SOTA Methods
+
+See [this directory](eval_sota/) for details. We include the evaluation scripts of some methods there.
+
+
## Dataset Statistics
[A python script](scripts/get_d3_stat.py) is provided for calculating the statistics of $D^3$ or visualizing figures like histograms, word clouds, etc.
+
+The specific statistics of the dataset are available in Section 3.3 of the DOD paper.
diff --git a/eval_sota/README.md b/eval_sota/README.md
new file mode 100644
index 0000000..e72f60c
--- /dev/null
+++ b/eval_sota/README.md
@@ -0,0 +1,27 @@
+# Evaluting SOTA Methods on $D^3$
+
+## Leaderboard
+
+In this directory, we keep the scripts or github links (official or custom) to evaluate SOTA methods (REC/OVD/DOD/MLLM) on $D^3$:
+
+| Name | Paper | Original Tasks | Training Data | Evaluation Code | Intra-FULL/PRES/ABS/Inter-FULL/PRES/ABS | Source | Note |
+|:-----|:-----:|:----:|:-----:|:-----:|:-----:|:-----:|:-----:|
+| OFA-large | [OFA: Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework (ICML 2022)](https://arxiv.org/abs/2202.03052) | REC | - | - | 4.2/4.1/4.6/0.1/0.1/0.1 | [DOD paper](https://arxiv.org/abs/2307.12813) | - |
+| CORA-R50 | [CORA: Adapting CLIP for Open-Vocabulary Detection with Region Prompting and Anchor Pre-Matching (CVPR 2023)](https://openaccess.thecvf.com/content/CVPR2023/papers/Wu_CORA_Adapting_CLIP_for_Open-Vocabulary_Detection_With_Region_Prompting_and_CVPR_2023_paper.pdf) | OVD | - | - | 6.2/6.7/5.0/2.0/2.2/1.3 | [DOD paper](https://arxiv.org/abs/2307.12813) | - |
+| OWL-ViT-large | [Simple Open-Vocabulary Object Detection with Vision Transformers (ECCV 2022)](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136700714.pdf) | OVD | - | [DOD official](./owl_vit.py) | 9.6/10.7/6.4/2.5/2.9/2.1 | [DOD paper](https://arxiv.org/abs/2307.12813) | Post-processing hyper-parameters may affect the performance and the result may not exactly match the paper |
+| SPHINX-7B | [SPHINX: The Joint Mixing of Weights, Tasks, and Visual Embeddings for Multi-modal Large Language Models (arxiv 2023)](https://arxiv.org/abs/2311.07575) | **MLLM** capable of REC | - | [DOD official](./sphinx.py) | 10.6/11.4/7.9/-/-/- | DOD authors | A lot of contribution from [Jie Li](https://github.com/theFool32) |
+| GLIP-T | [Grounded Language-Image Pre-training (CVPR 2022)](https://arxiv.org/abs/2112.03857) | OVD & PG | - | - | 19.1/18.3/21.5/-/-/- | GEN paper | - |
+| UNINEXT-huge | [Universal Instance Perception as Object Discovery and Retrieval (CVPR 2023)](https://arxiv.org/abs/2303.06674v2) | OVD & REC | - | [DOD official](https://github.com/Charles-Xie/UNINEXT_D3) | 20.0/20.6/18.1/3.3/3.9/1.6 | [DOD paper](https://arxiv.org/abs/2307.12813) | - |
+| Grounding-DINO-base | [Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection (arxiv 2023)](https://arxiv.org/abs/2303.05499) | OVD & REC | - | [DOD official](./groundingdino.py) | 20.7/20.1/22.5/2.7/2.4/3.5 | [DOD paper](https://arxiv.org/abs/2307.12813) | Post-processing hyper-parameters may affect the performance and the result may not exactly match the paper |
+| OFA-DOD-base | [Described Object Detection: Liberating Object Detection with Flexible Expressions (NeurIPS 2023)](https://arxiv.org/abs/2307.12813) | DOD | - | - | 21.6/23.7/15.4/5.7/6.9/2.3 | [DOD paper](https://arxiv.org/abs/2307.12813) | - |
+| FIBER-B | [Coarse-to-Fine Vision-Language Pre-training with Fusion in the Backbone (NeurIPS 2022)](https://arxiv.org/abs/2206.07643) | OVD & REC | - | - | 22.7/21.5/26.0/-/-/- | GEN paper | - |
+| MM-Grounding-DINO | [An Open and Comprehensive Pipeline for Unified Object Grounding and Detection (arxiv 2024)](https://arxiv.org/abs/2401.02361) | DOD & OVD & REC | O365, GoldG, GRIT, V3Det | [MM-GDINO official](https://github.com/open-mmlab/mmdetection/tree/main/configs/mm_grounding_dino#zero-shot-description-detection-datasetdod) | 22.9/21.9/26.0/-/-/- | MM-GDINO paper | - |
+| GEN (FIBER-B) | [Generating Enhanced Negatives for Training Language-Based Object Detectors (arxiv 2024](https://arxiv.org/abs/2401.00094) | DOD | - | - | 26.0/25.2/28.1/-/-/- | GEN paper | Enhancement based on FIBER-B |
+| APE-large (D) | [Aligning and Prompting Everything All at Once for Universal Visual Perception (arxiv 2023)](https://arxiv.org/abs/2312.02153) | DOD & OVD & REC | COCO, LVIS, O365, OpenImages, Visual Genome, RefCOCO/+/g, SA-1B, GQA, PhraseCut, Flickr30k | [APE official](https://github.com/shenyunhang/APE) | 37.5/38.8/33.9/21.0/22.0/17.9 | APE paper | Extra training data helps for this amazing performance |
+
+
+Some extra notes:
+- Each method is currently recorded by *the variant with the highest performance* in this table, if there are multiple variants available, so it's only a leaderboard, not meant for fair comparison.
+- Methods like GLIP, FIBER, etc. are actually not evaluated on OVD benchmarks. For zero-shot eval on DOD, We currently do not distinguish between methods for OVD benchmarks and methods for ZS-OD, as long as it is verified with open-set detection capability.
+
+For other variants (e.g. for a fair comparison regarding data, backbone, etc.), please refer to the papers.
diff --git a/eval_sota/groundingdino.py b/eval_sota/groundingdino.py
new file mode 100644
index 0000000..f987925
--- /dev/null
+++ b/eval_sota/groundingdino.py
@@ -0,0 +1,304 @@
+# -*- coding: utf-8 -*-
+__author__ = "Chi Xie"
+__maintainer__ = "Chi Xie"
+
+# An example for how to run this script:
+# CUDA_VISIBLE_DEVICES=0
+# python groundingdino.py \
+# -c ./groundingdino/config/GroundingDINO_SwinB.cfg.py \
+# -p ./ckpt/groundingdino_swinb_cogcoor.pth \
+# -o "outputs/gdino_d3" \
+# --box_threshold 0.05 \
+# --text_threshold 0.05 \
+# --img-top1
+
+import argparse
+import json
+import os
+
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+from tqdm import tqdm
+
+import groundingdino.datasets.transforms as T
+from groundingdino.models import build_model
+from groundingdino.util.slconfig import SLConfig
+from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
+from d_cube import D3
+
+
+def plot_boxes_to_image(image_pil, tgt):
+ H, W = tgt["size"]
+ boxes = tgt["boxes"]
+ labels = tgt["labels"]
+ assert len(boxes) == len(labels), "boxes and labels must have same length"
+
+ draw = ImageDraw.Draw(image_pil)
+ mask = Image.new("L", image_pil.size, 0)
+ mask_draw = ImageDraw.Draw(mask)
+
+ # draw boxes and masks
+ for box, label in zip(boxes, labels):
+ # from 0..1 to 0..W, 0..H
+ box = box * torch.Tensor([W, H, W, H])
+ # from xywh to xyxy
+ box[:2] -= box[2:] / 2
+ box[2:] += box[:2]
+ # random color
+ color = tuple(np.random.randint(0, 255, size=3).tolist())
+ # draw
+ x0, y0, x1, y1 = box
+ x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
+
+ draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
+ # draw.text((x0, y0), str(label), fill=color)
+
+ font = ImageFont.load_default()
+ if hasattr(font, "getbbox"):
+ bbox = draw.textbbox((x0, y0), str(label), font)
+ else:
+ w, h = draw.textsize(str(label), font)
+ bbox = (x0, y0, w + x0, y0 + h)
+ # bbox = draw.textbbox((x0, y0), str(label))
+ draw.rectangle(bbox, fill=color)
+ draw.text((x0, y0), str(label), fill="white")
+
+ mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
+ return image_pil, mask
+
+
+def load_image(image_path):
+ # load image
+ image_pil = Image.open(image_path).convert("RGB") # load image
+
+ transform = T.Compose(
+ [
+ T.RandomResize([800], max_size=1333),
+ T.ToTensor(),
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ]
+ )
+ image, _ = transform(image_pil, None) # 3, h, w
+ return image_pil, image
+
+
+def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
+ args = SLConfig.fromfile(model_config_path)
+ args.device = "cuda" if not cpu_only else "cpu"
+ model = build_model(args)
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
+ load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
+ print(load_res)
+ _ = model.eval()
+ return model
+
+
+def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
+ caption = caption.lower()
+ caption = caption.strip()
+ if not caption.endswith("."):
+ caption = caption + "."
+ device = "cuda" if not cpu_only else "cpu"
+ model = model.to(device)
+ image = image.to(device)
+ with torch.no_grad():
+ outputs = model(image[None], captions=[caption])
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
+ logits.shape[0]
+
+ # filter output
+ logits_filt = logits.clone()
+ boxes_filt = boxes.clone()
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
+ logits_filt.shape[0]
+
+ # get phrase
+ tokenlizer = model.tokenizer
+ tokenized = tokenlizer(caption)
+ # build pred
+ pred_phrases = []
+ logits_list = []
+ for logit, box in zip(logits_filt, boxes_filt):
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
+ logits_list.append(logit.max().item())
+ if with_logits:
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
+ else:
+ pred_phrases.append(pred_phrase)
+
+ return boxes_filt, pred_phrases, logits_list
+
+
+def get_dataset_iter(coco):
+ img_ids = coco.get_img_ids()
+ for img_id in img_ids:
+ img_info = coco.load_imgs(img_id)[0]
+ file_name = img_info["file_name"]
+ img_path = os.path.join(IMG_ROOT, file_name)
+ yield img_id, img_path
+
+
+def eval_on_d3(pred_path, mode="pn"):
+ assert mode in ("pn", "p", "n")
+ if mode == "pn":
+ gt_path = os.path.join(JSON_ANNO_PATH, "d3_full_annotations.json")
+ elif mode == "p":
+ gt_path = os.path.join(JSON_ANNO_PATH, "d3_pres_annotations.json")
+ else:
+ gt_path = os.path.join(JSON_ANNO_PATH, "d3_abs_annotations.json")
+ coco = COCO(gt_path)
+ d3_res = coco.loadRes(pred_path)
+ cocoEval = COCOeval(coco, d3_res, "bbox")
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+
+ # comment the following if u only need intra/inter map for full/pres/abs
+ # ===================== uncomment this if u need detailed analysis =====================
+ # aps = cocoEval.eval["precision"][:, :, :, 0, -1]
+ # category_ids = coco.getCatIds()
+ # category_names = [cat["name"] for cat in coco.loadCats(category_ids)]
+
+ # aps_lens = defaultdict(list)
+ # counter_lens = defaultdict(int)
+ # for i in range(len(category_names)):
+ # ap = aps[:, :, i]
+ # ap_value = ap[ap > -1].mean()
+ # if not np.isnan(ap_value):
+ # len_ref = len(category_names[i].split(" "))
+ # aps_lens[len_ref].append(ap_value)
+ # counter_lens[len_ref] += 1
+
+ # ap_sum_short = sum([sum(aps_lens[i]) for i in range(0, 4)])
+ # ap_sum_mid = sum([sum(aps_lens[i]) for i in range(4, 7)])
+ # ap_sum_long = sum([sum(aps_lens[i]) for i in range(7, 10)])
+ # ap_sum_very_long = sum(
+ # [sum(aps_lens[i]) for i in range(10, max(counter_lens.keys()) + 1)]
+ # )
+ # c_sum_short = sum([counter_lens[i] for i in range(1, 4)])
+ # c_sum_mid = sum([counter_lens[i] for i in range(4, 7)])
+ # c_sum_long = sum([counter_lens[i] for i in range(7, 10)])
+ # c_sum_very_long = sum(
+ # [counter_lens[i] for i in range(10, max(counter_lens.keys()) + 1)]
+ # )
+ # map_short = ap_sum_short / c_sum_short
+ # map_mid = ap_sum_mid / c_sum_mid
+ # map_long = ap_sum_long / c_sum_long
+ # map_very_long = ap_sum_very_long / c_sum_very_long
+ # print(
+ # f"mAP over reference length: short - {map_short:.4f}, mid - {map_mid:.4f}, long - {map_long:.4f}, very long - {map_very_long:.4f}"
+ # )
+ # ===================== uncomment this if u need detailed analysis =====================
+
+
+def inference_on_d3(data_iter, model, args, box_threshold, text_threshold):
+ pred = []
+ for idx, (img_id, image_path) in enumerate(tqdm(data_iter)):
+ # load image
+ image_pil, image = load_image(image_path)
+ size = image_pil.size
+ W, H = size
+
+ group_ids = d3.get_group_ids(img_ids=[img_id])
+ sent_ids = d3.get_sent_ids(group_ids=group_ids)
+ sent_list = d3.load_sents(sent_ids=sent_ids)
+ text_list = [sent['raw_sent'] for sent in sent_list]
+
+ for sent_id, text_prompt in zip(sent_ids, text_list):
+ # run model
+ boxes_filt, pred_phrases, logit_list = get_grounding_output(
+ model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only, with_logits=False,
+ )
+ if args.vis:
+ pred_dict = {
+ "boxes": boxes_filt, # [x_center, y_center, w, h]
+ "size": [size[1], size[0]],
+ "labels": [f"{phrase}({str(logit)[:4]})" for phrase, logit in zip(pred_phrases, logit_list)],
+ }
+ image_with_box = plot_boxes_to_image(image_pil.copy(), pred_dict)[0]
+ image_with_box.save(os.path.join(output_dir, f"{img_id}_{text_prompt}.jpg"))
+ if not logit_list:
+ continue
+ if args.img_top1:
+ max_score_idx = logit_list.index(max(logit_list))
+ bboxes, phrases, logits = [boxes_filt[max_score_idx]], [pred_phrases[max_score_idx]], [logit_list[max_score_idx]]
+ else:
+ bboxes, phrases, logits = boxes_filt, pred_phrases, logit_list
+ for box, phrase, logit in zip(bboxes, phrases, logits):
+ if len(phrase) > args.overlap_percent * len(text_prompt) or phrase == text_prompt:
+ x1, y1, w, h = box.tolist()
+ x0, y0 = x1 - w / 2, y1 - h / 2
+ pred_item = {
+ "image_id": img_id,
+ "category_id": sent_id,
+ "bbox": [x0 * W, y0 * H, w * W, h * H],
+ "score": float(logit),
+ }
+ pred.append(pred_item)
+
+ return pred
+
+
+if __name__ == "__main__":
+ IMG_ROOT = None # set here
+ JSON_ANNO_PATH = None # set here
+ PKL_ANNO_PATH = None # set here
+ assert IMG_ROOT is not None, "Please set IMG_ROOT in the script first"
+ assert JSON_ANNO_PATH is not None, "Please set JSON_ANNO_PATH in the script first"
+ assert PKL_ANNO_PATH is not None, "Please set PKL_ANNO_PATH in the script first"
+
+ d3 = D3(IMG_ROOT, PKL_ANNO_PATH)
+
+ parser = argparse.ArgumentParser("Grounding DINO evaluation on D-cube (https://arxiv.org/abs/2307.12813)", add_help=True)
+ parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file")
+ parser.add_argument(
+ "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
+ )
+ # parser.add_argument("--image_path", "-i", type=str, required=True, help="path to image file")
+ # parser.add_argument("--text_prompt", "-t", type=str, required=True, help="text prompt")
+ parser.add_argument(
+ "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
+ )
+ parser.add_argument("--vis", action="store_true", help="visualization on D3")
+
+ parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
+ parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
+
+ parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
+ parser.add_argument("--img-top1", action="store_true", help="select only the box with top max score")
+ # parser.add_argument("--overlap-percent", type=float, default=1.0, help="overlapping percentage between input prompt and output label")
+ # this overlapping percentage denotes an additional post-processing technique we designed. if you turn this on, you may get higher performance by tuning this parameter.
+ args = parser.parse_args()
+ args.overlap_percent = 1 # by default, we do not use this technique.
+ print(args)
+
+ # cfg
+ config_file = args.config_file # change the path of the model config file
+ checkpoint_path = args.checkpoint_path # change the path of the model
+ # image_path = args.image_path
+ # text_prompt = args.text_prompt
+ output_dir = args.output_dir
+ box_threshold = args.box_threshold
+ text_threshold = args.text_threshold
+
+ # make dir
+ os.makedirs(output_dir, exist_ok=True)
+ # load model
+ model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)
+
+ data_iter = get_dataset_iter(d3)
+
+ pred = inference_on_d3(data_iter, model, args, box_threshold=box_threshold, text_threshold=text_threshold)
+
+ pred_path = os.path.join(output_dir, f"prediction.json")
+ with open(pred_path, "w") as f_:
+ json.dump(pred, f_)
+ eval_on_d3(pred_path, mode='pn')
+ eval_on_d3(pred_path, mode='p')
+ eval_on_d3(pred_path, mode='n')
diff --git a/eval_sota/owl_vit.py b/eval_sota/owl_vit.py
index d8ccf78..bc28f6f 100755
--- a/eval_sota/owl_vit.py
+++ b/eval_sota/owl_vit.py
@@ -1,24 +1,32 @@
import json
import os
-import time
-import numpy as np
from collections import defaultdict
-import logging
-import torch
+from tqdm import tqdm
from PIL import Image
+import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
-from transformers import OwlViTProcessor, OwlViTForObjectDetection
-from PIL import Image
import torch
+from transformers import OwlViTProcessor, OwlViTForObjectDetection
from d_cube import D3
+def write_json(json_path, json_data):
+ with open(json_path, "w") as f_:
+ json.dump(json_data, f_)
+
+
+def read_json(json_path):
+ with open(json_path, "r") as f_:
+ json_data = json.load(f_)
+ return json_data
+
+
def load_image_general(image_path):
image_pil = Image.open(image_path)
- return image_pil, image_pil
+ return image_pil
def get_prediction(model, image, captions, cpu_only=False):
@@ -36,9 +44,16 @@ def get_prediction(model, image, captions, cpu_only=False):
outputs = model(**inputs)
target_size = torch.Tensor([image.size[::-1]]).to(device)
results = processor.post_process_object_detection(
- outputs=outputs, target_sizes=target_size, threshold=0.05
+ outputs=outputs, target_sizes=target_size, threshold=0.1
+ # the post precessing threshold will affect the performance obviously
+ # you may tune it to get better performance, e.g., 0.05
)
- return results
+ boxes, scores, labels = (
+ results[0]["boxes"],
+ results[0]["scores"],
+ results[0]["labels"],
+ )
+ return boxes, scores, labels
def get_dataset_iter(coco):
@@ -65,83 +80,85 @@ def eval_on_d3(pred_path, mode="pn"):
cocoEval.accumulate()
cocoEval.summarize()
- aps = cocoEval.eval["precision"][:, :, :, 0, -1]
- category_ids = coco.getCatIds()
- category_names = [cat["name"] for cat in coco.loadCats(category_ids)]
-
- aps_lens = defaultdict(list)
- counter_lens = defaultdict(int)
- for i in range(len(category_names)):
- ap = aps[:, :, i]
- ap_value = ap[ap > -1].mean()
- if not np.isnan(ap_value):
- len_ref = len(category_names[i].split(" "))
- aps_lens[len_ref].append(ap_value)
- counter_lens[len_ref] += 1
-
- ap_sum_short = sum([sum(aps_lens[i]) for i in range(0, 4)])
- ap_sum_mid = sum([sum(aps_lens[i]) for i in range(4, 7)])
- ap_sum_long = sum([sum(aps_lens[i]) for i in range(7, 10)])
- ap_sum_very_long = sum(
- [sum(aps_lens[i]) for i in range(10, max(counter_lens.keys()) + 1)]
- )
- c_sum_short = sum([counter_lens[i] for i in range(1, 4)])
- c_sum_mid = sum([counter_lens[i] for i in range(4, 7)])
- c_sum_long = sum([counter_lens[i] for i in range(7, 10)])
- c_sum_very_long = sum(
- [counter_lens[i] for i in range(10, max(counter_lens.keys()) + 1)]
- )
- map_short = ap_sum_short / c_sum_short
- map_mid = ap_sum_mid / c_sum_mid
- map_long = ap_sum_long / c_sum_long
- map_very_long = ap_sum_very_long / c_sum_very_long
- print(
- f"mAP over reference length: short - {map_short:.4f}, mid - {map_mid:.4f}, long - {map_long:.4f}, very long - {map_very_long:.4f}"
- )
+ # comment the following if u only need intra/inter map for full/pres/abs
+ # ===================== uncomment this if u need detailed analysis =====================
+ # aps = cocoEval.eval["precision"][:, :, :, 0, -1]
+ # category_ids = coco.getCatIds()
+ # category_names = [cat["name"] for cat in coco.loadCats(category_ids)]
+
+ # aps_lens = defaultdict(list)
+ # counter_lens = defaultdict(int)
+ # for i in range(len(category_names)):
+ # ap = aps[:, :, i]
+ # ap_value = ap[ap > -1].mean()
+ # if not np.isnan(ap_value):
+ # len_ref = len(category_names[i].split(" "))
+ # aps_lens[len_ref].append(ap_value)
+ # counter_lens[len_ref] += 1
+
+ # ap_sum_short = sum([sum(aps_lens[i]) for i in range(0, 4)])
+ # ap_sum_mid = sum([sum(aps_lens[i]) for i in range(4, 7)])
+ # ap_sum_long = sum([sum(aps_lens[i]) for i in range(7, 10)])
+ # ap_sum_very_long = sum(
+ # [sum(aps_lens[i]) for i in range(10, max(counter_lens.keys()) + 1)]
+ # )
+ # c_sum_short = sum([counter_lens[i] for i in range(1, 4)])
+ # c_sum_mid = sum([counter_lens[i] for i in range(4, 7)])
+ # c_sum_long = sum([counter_lens[i] for i in range(7, 10)])
+ # c_sum_very_long = sum(
+ # [counter_lens[i] for i in range(10, max(counter_lens.keys()) + 1)]
+ # )
+ # map_short = ap_sum_short / c_sum_short
+ # map_mid = ap_sum_mid / c_sum_mid
+ # map_long = ap_sum_long / c_sum_long
+ # map_very_long = ap_sum_very_long / c_sum_very_long
+ # print(
+ # f"mAP over reference length: short - {map_short:.4f}, mid - {map_mid:.4f}, long - {map_long:.4f}, very long - {map_very_long:.4f}"
+ # )
+ # ===================== uncomment this if u need detailed analysis =====================
def inference_on_d3(data_iter, model):
pred = []
error = []
- for idx, (img_id, image_path) in enumerate(data_iter):
- logging.critical(idx)
- logging.critical(time.asctime(time.localtime(time.time())))
- image_pil, image = load_image_general(image_path)
+ for img_id, image_path in tqdm(data_iter):
+ image = load_image_general(image_path)
- # group
+ # ==================================== intra-group setting ====================================
+ # each image is evaluated with the categories in its group (usually 4)
group_ids = d3.get_group_ids(img_ids=[img_id])
sent_ids = d3.get_sent_ids(group_ids=group_ids)
+ # ==================================== intra-group setting ====================================
+ # ==================================== inter-group setting ====================================
+ # each image is evaluated with all categories in the dataset (422 for the first version of the dataset)
+ # sent_ids = d3.get_sent_ids()
+ # ==================================== inter-group setting ====================================
sent_list = d3.load_sents(sent_ids=sent_ids)
text_list = [sent["raw_sent"] for sent in sent_list]
try:
- results = get_prediction(model, image, text_list, cpu_only=False)
- i = 0
- boxes, scores, labels = (
- results[i]["boxes"],
- results[i]["scores"],
- results[i]["labels"],
- )
+ boxes, scores, labels = get_prediction(model, image, text_list, cpu_only=False)
for box, score, label in zip(boxes, scores, labels):
pred_item = {
"image_id": img_id,
"category_id": sent_ids[label],
- "bbox": box.tolist(),
+ "bbox": convert_to_xywh(box.tolist()), # use xywh
"score": float(score),
}
- pred.append(pred_item)
+ pred.append(pred_item) # the output to be saved to JSON.
except:
print("error!!!")
return pred, error
-def convert_to_xywh(x1, y1, x2, y2):
+def convert_to_xywh(bbox_xyxy):
"""
- Convert top-left and bottom-right corner coordinates to x,y,width,height format.
+ Convert top-left and bottom-right corner coordinates to [x, y, width, height] format.
"""
+ x1, y1, x2, y2 = bbox_xyxy
width = x2 - x1
height = y2 - y1
- return x1, y1, width, height
+ return [x1, y1, width, height]
if __name__ == "__main__":
@@ -151,34 +168,25 @@ def convert_to_xywh(x1, y1, x2, y2):
assert IMG_ROOT is not None, "Please set IMG_ROOT in the script first"
assert JSON_ANNO_PATH is not None, "Please set JSON_ANNO_PATH in the script first"
assert PKL_ANNO_PATH is not None, "Please set PKL_ANNO_PATH in the script first"
+
d3 = D3(IMG_ROOT, PKL_ANNO_PATH)
output_dir = "ovd/owlvit/"
os.makedirs(output_dir, exist_ok=True)
- # model predicting
+ # model prediction
processor = OwlViTProcessor.from_pretrained("owl-vit")
model = OwlViTForObjectDetection.from_pretrained("owl-vit")
data_iter = get_dataset_iter(d3)
pred, error = inference_on_d3(data_iter, model)
- pred_path = os.path.join(output_dir, f"eval_d3.json")
+ pred_path = os.path.join(output_dir, f"prediction.json")
pred_path_error = os.path.join(output_dir, "error.json")
-
- with open(pred_path, "w") as f_:
- json.dump(pred, f_)
- with open(pred_path_error, "w") as f2:
- json.dump(error, f2)
-
- # change to xywh format of bbox
- with open(pred_path, "r") as f_:
- res = json.load(f_)
- for item in res:
- item["bbox"] = convert_to_xywh(*item["bbox"])
- res_path = pred_path.replace(".json", ".xywh.json")
- with open(res_path, "w") as f_w:
- json.dump(res, f_w)
-
- eval_on_d3(res_path, mode="pn")
- eval_on_d3(res_path, mode="p")
- eval_on_d3(res_path, mode="n")
+ write_json(pred_path, pred)
+ write_json(pred_path_error, error)
+ # see https://github.com/shikras/d-cube/blob/main/doc.md#output-format for the output format
+ # the output format is identical to COCO.
+
+ eval_on_d3(pred_path, mode="pn") # the FULL setting
+ eval_on_d3(pred_path, mode="p") # the PRES setting
+ eval_on_d3(pred_path, mode="n") # the ABS setting
diff --git a/eval_sota/sphinx.py b/eval_sota/sphinx.py
new file mode 100644
index 0000000..1d9a3e4
--- /dev/null
+++ b/eval_sota/sphinx.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+__author__ = "Chi Xie and Jie Li"
+__maintainer__ = "Chi Xie"
+
+import json
+import os
+from collections import defaultdict
+import re
+
+from PIL import Image
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+
+from d_cube import D3
+
+
+def write_json(json_path, json_data):
+ with open(json_path, "w") as f_:
+ json.dump(json_data, f_)
+
+
+def read_json(json_path):
+ with open(json_path, "r") as f_:
+ json_data = json.load(f_)
+ return json_data
+
+
+def load_image_general(image_path):
+ image_pil = Image.open(image_path)
+ return image_pil
+
+
+def extract_boxes(input_string):
+ # if input_string.startswith("None"):
+ # return []
+ # Define the pattern using regular expression
+ pattern = r'\[([\d.,; ]+)\]'
+
+ # Search for the pattern in the input string
+ match = re.search(pattern, input_string)
+
+ # If a match is found, extract and return the boxes as a list
+ if match:
+ boxes_str = match.group(1)
+ boxes_list = [list(map(float, box.split(','))) for box in boxes_str.split(';')]
+ return boxes_list
+ else:
+ return []
+
+
+def get_prediction(mllm_res, image, captions, cpu_only=False):
+ boxes, scores, labels = [], [], []
+ width, height = image.size
+ for idx, res_item in enumerate(mllm_res):
+ boxes_list = extract_boxes(res_item["answer"])
+ for bbox in boxes_list:
+ bbox_rescaled = get_true_bbox(image.size, bbox)
+ boxes.append(bbox_rescaled)
+ scores.append(1.0)
+ labels.append(idx)
+ return boxes, scores, labels
+
+
+def get_dataset_iter(coco):
+ img_ids = coco.get_img_ids()
+ for img_id in img_ids:
+ img_info = coco.load_imgs(img_id)[0]
+ file_name = img_info["file_name"]
+ img_path = os.path.join(IMG_ROOT, file_name)
+ yield img_id, file_name, img_path
+
+
+def eval_on_d3(pred_path, mode="pn"):
+ assert mode in ("pn", "p", "n")
+ if mode == "pn":
+ gt_path = os.path.join(JSON_ANNO_PATH, "d3_full_annotations.json")
+ elif mode == "p":
+ gt_path = os.path.join(JSON_ANNO_PATH, "d3_pres_annotations.json")
+ else:
+ gt_path = os.path.join(JSON_ANNO_PATH, "d3_abs_annotations.json")
+ coco = COCO(gt_path)
+ d3_res = coco.loadRes(pred_path)
+ cocoEval = COCOeval(coco, d3_res, "bbox")
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+
+
+def group_sphinx_res_by_img(inference_res):
+ inference_res_by_img = defaultdict(list)
+ for res_item in inference_res:
+ img_path = "/".join(res_item["image_path"].split("/")[-2:])
+ inference_res_by_img[img_path].append(res_item)
+ inference_res_by_img = dict(inference_res_by_img)
+ return inference_res_by_img
+
+
+def get_true_bbox(img_size, bbox):
+ width, height = img_size
+ max_edge = max(height, width)
+ bbox = [v * max_edge for v in bbox]
+ diff = abs(width - height) // 2
+ if height < width:
+ bbox[1] -= diff
+ bbox[3] -= diff
+ else:
+ bbox[0] -= diff
+ bbox[2] -= diff
+ return bbox
+
+
+def inference_on_d3(data_iter, inference_res):
+ pred = []
+ inf_res_by_img = group_sphinx_res_by_img(inference_res)
+ for idx, (img_id, img_name, img_path) in enumerate(data_iter):
+ image = load_image_general(img_path)
+
+ # ==================================== intra-group setting ====================================
+ # each image is evaluated with the categories in its group (usually 4)
+ group_ids = d3.get_group_ids(img_ids=[img_id])
+ sent_ids = d3.get_sent_ids(group_ids=group_ids)
+ # ==================================== intra-group setting ====================================
+ # ==================================== inter-group setting ====================================
+ # each image is evaluated with all categories in the dataset (422 for the first version of the dataset)
+ # sent_ids = d3.get_sent_ids()
+ # ==================================== inter-group setting ====================================
+ sent_list = d3.load_sents(sent_ids=sent_ids)
+ text_list = [sent["raw_sent"] for sent in sent_list]
+
+ boxes, scores, labels = get_prediction(inf_res_by_img[img_name], image, text_list, cpu_only=False)
+ for box, score, label in zip(boxes, scores, labels):
+ pred_item = {
+ "image_id": img_id,
+ "category_id": sent_ids[label],
+ "bbox": convert_to_xywh(box), # use xywh
+ "score": float(score),
+ }
+ pred.append(pred_item) # the output to be saved to JSON.
+ return pred
+
+
+def convert_to_xywh(bbox_xyxy):
+ """
+ Convert top-left and bottom-right corner coordinates to [x, y, width, height] format.
+ """
+ x1, y1, x2, y2 = bbox_xyxy
+ width = x2 - x1
+ height = y2 - y1
+ return [x1, y1, width, height]
+
+
+if __name__ == "__main__":
+ IMG_ROOT = None # set here
+ JSON_ANNO_PATH = None # set here
+ PKL_ANNO_PATH = None # set here
+ # ============================== SPHINX inference result file ===============
+ SPHINX_INFERENCE_RES_PATH = None
+ # You can download the SPHINX d3 inference result example from:
+ # https://github.com/shikras/d-cube/files/14276682/sphinx_d3_result.json
+ # For the inference process, please refer to SPHINX official repo (https://github.com/Alpha-VLLM/LLaMA2-Accessory)
+ # the prompts we used are available in this JSON file
+ # Thanks for the contribution from Jie Li (https://github.com/theFool32)
+ # ============================== SPHINX inference result file ===============
+ assert IMG_ROOT is not None, "Please set IMG_ROOT in the script first"
+ assert JSON_ANNO_PATH is not None, "Please set JSON_ANNO_PATH in the script first"
+ assert PKL_ANNO_PATH is not None, "Please set PKL_ANNO_PATH in the script first"
+
+ d3 = D3(IMG_ROOT, PKL_ANNO_PATH)
+
+ output_dir = "mllm/sphinx/" # or whatever you prefer
+ inference_res = read_json(SPHINX_INFERENCE_RES_PATH)
+
+ # model prediction
+ data_iter = get_dataset_iter(d3)
+ pred = inference_on_d3(data_iter, inference_res)
+
+ pred_path = os.path.join(output_dir, f"prediction.json")
+ write_json(pred_path, pred)
+ # see https://github.com/shikras/d-cube/blob/main/doc.md#output-format for the output format
+ # the output format is identical to COCO.
+
+ eval_on_d3(pred_path, mode="pn") # the FULL setting
+ eval_on_d3(pred_path, mode="p") # the PRES setting
+ eval_on_d3(pred_path, mode="n") # the ABS setting
diff --git a/qa.md b/qa.md
new file mode 100644
index 0000000..dc65758
--- /dev/null
+++ b/qa.md
@@ -0,0 +1,24 @@
+# Frequently Asked Questions
+
+Q:
+What's the difference between Intra-Group and Inter-Group setting in [the DOD paper](https://arxiv.org/abs/2307.12813), and how to set them?
+
+A:
+Please see [this explanation in the document](./doc.md#intra--or-inter-group-settings).
+
+
+
+Q:
+What's the meaning of and difference between FULL, PRES, and ABS?
+
+A:
+Please see [this explanation in the document](./doc.md#full-pres-and-abs).
+
+
+
+Q:
+How do I perform a visualization of ground truth or prediction on a image?
+
+A:
+You can use `d3.get_anno_ids` function and pass the `img_id` you choose as parameter to get the annotation ids for a image.
+After this, you can obtain the annotation details (class ids, bboxes) with `d3.load_annos`.
diff --git a/scripts/eval_and_analysis_json.py b/scripts/eval_and_analysis_json.py
index e5471a5..beba3bf 100755
--- a/scripts/eval_and_analysis_json.py
+++ b/scripts/eval_and_analysis_json.py
@@ -144,9 +144,9 @@ def transform_json_boxes(pred_path):
if __name__ == "__main__":
- IMG_ROOT = None # set here
- JSON_ANNO_PATH = None # set here
- PKL_ANNO_PATH = None # set here
+ IMG_ROOT = "/Users/xiechi/Development/DOD/d3_dataset/d3_release/"
+ JSON_ANNO_PATH = "/Users/xiechi/Development/DOD/d3_dataset/d3_release/d3_json/"
+ PKL_ANNO_PATH = "/Users/xiechi/Development/DOD/d3_dataset/d3_release/d3_pkl/"
assert IMG_ROOT is not None, "Please set IMG_ROOT in the script first"
assert JSON_ANNO_PATH is not None, "Please set JSON_ANNO_PATH in the script first"
assert PKL_ANNO_PATH is not None, "Please set PKL_ANNO_PATH in the script first"
diff --git a/setup.py b/setup.py
index 1a4d502..9d9fc79 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
setuptools.setup(
name='ddd-dataset',
- version='0.1.0',
+ version='0.1.1',
author='Chi Xie',
author_email='chixie.personal@gmail.com',
description='Toolkit for Description Detection Dataset ($D^3$)',