diff --git a/scenic/dataset_lib/coco_dataset/coco_utils.py b/scenic/dataset_lib/coco_dataset/coco_utils.py index 2b620fdbf..b1f5d67f1 100644 --- a/scenic/dataset_lib/coco_dataset/coco_utils.py +++ b/scenic/dataset_lib/coco_dataset/coco_utils.py @@ -18,6 +18,10 @@ import json from typing import Dict, Optional +import immutabledict + + +ImmutableDict = immutabledict.immutabledict OBJECTS365_LABEL_MAP_PATH = ( 'scenic/dataset_lib/coco_dataset/data/objects365_class_names.txt') diff --git a/scenic/dataset_lib/coco_dataset/tests/test_coco_utils.py b/scenic/dataset_lib/coco_dataset/tests/test_coco_utils.py index 1c1bc8bae..98fe7eb6b 100644 --- a/scenic/dataset_lib/coco_dataset/tests/test_coco_utils.py +++ b/scenic/dataset_lib/coco_dataset/tests/test_coco_utils.py @@ -19,6 +19,7 @@ from scenic.dataset_lib.coco_dataset import coco_utils + class CocoUtilsTest(parameterized.TestCase): """Test COCO utils.""" @@ -42,5 +43,6 @@ def test_get_label_map_unknown(self): ValueError, lambda m: m.args == ('Unsupported TFDS name: unknown',)): coco_utils.get_label_map('unknown') + if __name__ == '__main__': absltest.main()