diff --git a/README.md b/README.md
index 91088fa..e10a698 100644
--- a/README.md
+++ b/README.md
@@ -54,6 +54,7 @@ nrtk-explorer --dataset ./nrtk_explorer_datasets/coco-od-2017/mini_val2017.json
or a directory loadable by the [Dataset](https://huggingface.co/docs/datasets/index) library.
You can specify multiple datasets using a space as the
separator. Example: `nrtk-explorer --dataset ../foo-dir/coco.json cppe-5`
+- `--repository` Specify an existing directory where exported datasets will be saved to and loaded from.
- `--download` Cache Hugging Face Hub datasets locally instead of streaming them.
When datasets are streamed, nrtk-explorer limits the number of loaded images.
- `--models` specify the Hugging Face Hub [object detection](https://huggingface.co/models?pipeline_tag=object-detection&library=transformers&sort=trending)
diff --git a/src/nrtk_explorer/app/core.py b/src/nrtk_explorer/app/core.py
index 8f13b3c..49bd649 100644
--- a/src/nrtk_explorer/app/core.py
+++ b/src/nrtk_explorer/app/core.py
@@ -1,14 +1,22 @@
import logging
from typing import Iterable
+from pathlib import Path
+
from trame.widgets import html
from trame_server.utils.namespace import Translator
from nrtk_explorer.library.filtering import FilterProtocol
-from nrtk_explorer.library.dataset import get_dataset, expand_hugging_face_datasets
+from nrtk_explorer.library.dataset import (
+ get_dataset,
+ expand_hugging_face_datasets,
+ discover_datasets,
+ dataset_select_options,
+)
from nrtk_explorer.library.debounce import debounce
from nrtk_explorer.app.images.images import Images
from nrtk_explorer.app.embeddings import EmbeddingsApp
+from nrtk_explorer.app.export import ExportApp
from nrtk_explorer.app.transforms import TransformsApp
from nrtk_explorer.app.filtering import FilteringApp
from nrtk_explorer.app.applet import Applet
@@ -34,6 +42,14 @@
NUM_IMAGES_DEBOUNCE_TIME = 0.3 # seconds
+def dir_path(arg):
+    """Argparse ``type`` callable that resolves ``arg`` to an existing directory.
+
+    Args:
+        arg: Directory path as given on the command line.
+
+    Returns:
+        The resolved ``Path`` of the directory.
+
+    Raises:
+        argparse.ArgumentTypeError: If ``arg`` does not name an existing
+            directory. argparse reports ArgumentTypeError, TypeError and
+            ValueError as a usage error; other exception types raised by a
+            ``type`` callable escape as a traceback.
+    """
+    # Function-scope import keeps this block self-contained.
+    import argparse
+
+    path = Path(arg).resolve()
+    if not path.is_dir():
+        raise argparse.ArgumentTypeError(f"not an existing directory: {arg}")
+    return path
+
+
# ---------------------------------------------------------
# Engine class
# ---------------------------------------------------------
@@ -50,6 +66,13 @@ def __init__(self, server=None):
help="Path to the JSON file describing the image dataset",
)
+ self.server.cli.add_argument(
+ "--repository",
+ default=None,
+ type=dir_path,
+ help="Path to the directory where exported datasets will be saved to.",
+ )
+
self.server.cli.add_argument(
"--download",
action="store_true",
@@ -58,11 +81,21 @@ def __init__(self, server=None):
)
known_args, _ = self.server.cli.parse_known_args()
- dataset_identifiers = expand_hugging_face_datasets(
+
+ self.state.input_datasets = expand_hugging_face_datasets(
known_args.dataset, not known_args.download
)
- self.input_paths = dataset_identifiers
- self.state.current_dataset = self.input_paths[0]
+
+ self.context.repository = known_args.repository
+ self.state.repository_datasets = [
+ str(path) for path in discover_datasets(self.context.repository)
+ ]
+
+ self.state.all_datasets = self.state.input_datasets + self.state.repository_datasets
+
+ self.state.all_datasets_options = dataset_select_options(self.state.all_datasets)
+
+ self.state.current_dataset = self.state.all_datasets[0]
images = Images(server=self.server)
@@ -81,6 +114,10 @@ def __init__(self, server=None):
server=self.server.create_child_server(translator=filtering_translator),
)
+ self._export_app = ExportApp(
+ server=self.server.create_child_server(),
+ )
+
self._transforms_app.set_on_transform(self._embeddings_app.on_run_transformations)
self._embeddings_app.set_on_hover(self._transforms_app.on_image_hovered)
self._transforms_app.set_on_hover(self._embeddings_app.on_image_hovered)
@@ -153,7 +190,8 @@ def resample_images(self, **kwargs):
else:
selected_images = images
- self.state.dataset_ids = [str(img["id"]) for img in selected_images]
+ self.context.dataset_ids = [img["id"] for img in selected_images]
+ self.state.dataset_ids = [str(image_id) for image_id in self.context.dataset_ids]
def _build_ui(self):
extra_args = {}
@@ -163,9 +201,9 @@ def _build_ui(self):
self.ui = ui.NrtkExplorerLayout(
server=self.server,
- dataset_paths=self.input_paths,
embeddings_app=self._embeddings_app,
filtering_app=self._filtering_app,
transforms_app=self._transforms_app,
+ export_app=self._export_app,
**extra_args,
)
diff --git a/src/nrtk_explorer/app/export.py b/src/nrtk_explorer/app/export.py
new file mode 100644
index 0000000..276df47
--- /dev/null
+++ b/src/nrtk_explorer/app/export.py
@@ -0,0 +1,216 @@
+from pathlib import Path
+
+from trame.app import get_server, asynchronous
+from trame.widgets import quasar, html
+from trame.ui.quasar import QLayout
+
+from nrtk_explorer.app.applet import Applet
+from nrtk_explorer.library.dataset import (
+ discover_datasets,
+ dataset_select_options,
+)
+import nrtk_explorer.library.transforms as trans
+from nrtk_explorer.widgets.nrtk_explorer import ExportWidget
+
+
+def recursive_rmdir(path: Path):
+    """Delete the directory ``path`` together with everything below it.
+
+    Nothing happens when ``path`` is not a directory. Symbolic links are
+    removed as links and never followed, so data outside of ``path`` is left
+    untouched.
+
+    Args:
+        path: Directory to delete.
+    """
+    # Function-scope import keeps this block self-contained.
+    import shutil
+
+    if not path.is_dir():
+        return
+
+    if path.is_symlink():
+        # ``is_dir`` follows links; drop the link itself, not its target.
+        path.unlink()
+        return
+
+    # rmtree unlinks symlinks (dangling ones included) instead of descending
+    # into them, which a plain is_file()/is_dir() walk cannot tell apart.
+    shutil.rmtree(path)
+
+
+class ExportApp(Applet):
+    """Applet that writes the loaded dataset to the export repository.
+
+    An export applies the transforms held in ``context.transforms`` to either
+    every image of ``context.dataset`` or only the images listed in
+    ``context.dataset_ids``, saves the results together with a COCO JSON file
+    under ``context.repository / <name>`` and then refreshes the dataset
+    lists kept in the state.
+    """
+
+    # Sub-directory of the repository in which an export is staged before it
+    # is moved to its final location.
+    _STAGING_DIR = "tmp"
+    # Number of images processed between two progress updates.
+    _PROGRESS_UPDATE_STEP = 30
+    # Upper bound on the number of images written to one sub-directory.
+    _MAX_FILES_PER_DIRECTORY = 100
+
+    def __init__(self, server):
+        super().__init__(server)
+
+        self.context.setdefault("repository", "")
+        self.state.setdefault("current_dataset", "")
+        # export_dataset() concatenates this list with the repository
+        # datasets, so it needs a list even if nothing else has set it.
+        self.state.setdefault("input_datasets", [])
+        self.state.setdefault("repository_datasets", [])
+        self.state.setdefault("export_status", "idle")
+        self.state.setdefault("export_progress", 0)
+
+        self.server.controller.add("on_server_ready")(self.on_server_ready)
+
+        self._ui = None
+        # Task of the running (or last finished) export, if any.
+        self._export_task = None
+
+    def on_server_ready(self, *args, **kwargs):
+        """Controller hook registered in ``__init__``; currently does nothing."""
+        # Bind instance methods to state change
+        pass
+
+    def on_export_clicked(self, event):
+        """Start an export from the widget's export event.
+
+        Args:
+            event: Mapping with the keys ``"name"`` (export name) and
+                ``"full"`` (forwarded as the ``full`` flag).
+        """
+        self.start_export(event["name"], event["full"])
+
+    def start_export(self, name, full):
+        """Schedule an export task unless a previous one is still running."""
+        if self._exporting_dataset():
+            return
+
+        self._export_task = asynchronous.create_task(self.export_dataset(name, full))
+
+    async def export_dataset(self, name, full):
+        """Run an export and publish its status and progress to the state.
+
+        ``export_status`` is set to ``"pending"`` and then to ``"success"`` or
+        ``"fail"``. In either case ``export_progress`` ends at 1 and the
+        dataset lists are re-read from the repository.
+        """
+        with self.state:
+            self.state.export_status = "pending"
+            self.state.export_progress = 0
+        await self.server.network_completion
+
+        try:
+            await self._export_dataset(name, full)
+            with self.state:
+                self.state.export_status = "success"
+        except Exception:
+            # Function-scope import keeps this block self-contained.
+            import logging
+
+            # The state only carries "fail"; keep the cause in the log.
+            logging.getLogger(__name__).exception("Export of dataset %r failed", name)
+            with self.state:
+                self.state.export_status = "fail"
+        finally:
+            with self.state:
+                self.state.export_progress = 1
+                # Update list of available datasets. The repository defaults
+                # to "" in __init__; discover_datasets() only skips None.
+                self.state.repository_datasets = [
+                    str(path) for path in discover_datasets(self.context.repository or None)
+                ]
+                self.state.all_datasets = (
+                    self.state.input_datasets + self.state.repository_datasets
+                )
+                self.state.all_datasets_options = dataset_select_options(self.state.all_datasets)
+            await self.server.network_completion
+
+    @classmethod
+    def _check_export_name(cls, name):
+        """Raise ``ValueError`` unless ``name`` is a single, safe path component.
+
+        The name arrives in a client event and is joined onto the repository
+        path, which is then deleted recursively. It must therefore not be
+        able to point outside the repository or at the staging directory.
+        """
+        if (
+            not isinstance(name, str)
+            or name in ("", ".", "..", cls._STAGING_DIR)
+            or Path(name).name != name
+        ):
+            raise ValueError(f"invalid export name: {name!r}")
+
+    async def _export_dataset(self, name, full):
+        """Write the transformed images and their COCO description to disk.
+
+        The export is staged in ``<repository>/tmp/<name>`` and moved to
+        ``<repository>/<name>`` (replacing a previous export of that name)
+        only after the JSON file has been dumped.
+
+        Args:
+            name: Name of the export; used as directory and JSON file name.
+            full: When true, export every image of the dataset; otherwise
+                only the ids in ``context.dataset_ids``.
+
+        Raises:
+            ValueError: If ``name`` is not a single, safe path component.
+            RuntimeError: If no repository is configured.
+        """
+        self._check_export_name(name)
+
+        repository = self.context.repository
+        if not repository:
+            raise RuntimeError("no export repository configured")
+        repository = Path(repository)
+
+        tmp_dataset_dir = repository / self._STAGING_DIR / name
+        dataset_dir = repository / name
+        recursive_rmdir(tmp_dataset_dir)
+        tmp_dataset_dir.mkdir(parents=True)
+
+        # Ensure the transform parameters are frozen for the duration of the task
+        transforms = []
+        for t in self.context.transforms:
+            instance = t["instance"]
+            new_instance = instance.__class__()
+            new_instance.set_parameters(instance.get_parameters())
+            transforms.append(new_instance)
+
+        transform = trans.ChainedImageTransform(transforms)
+
+        dataset = self.context.dataset
+
+        if full:
+            image_ids = set(dataset.imgs.keys())
+        else:
+            image_ids = set(self.context.dataset_ids)
+
+        import kwcoco
+
+        new_dataset = kwcoco.CocoDataset()
+        total = len(image_ids)
+
+        for i, image_id in enumerate(image_ids):
+            # Start a new sub-directory every _MAX_FILES_PER_DIRECTORY images.
+            destination_dir = tmp_dataset_dir / str(i // self._MAX_FILES_PER_DIRECTORY)
+            destination_dir.mkdir(parents=True, exist_ok=True)
+
+            img = dataset.get_image(image_id)
+            img.load()  # Avoid OSError(24, 'Too many open files')
+            # transforms require RGB mode
+            img = img.convert("RGB") if img.mode != "RGB" else img
+
+            img_format = img.format if img.format is not None else "PNG"
+
+            img_destination = destination_dir / f"{image_id}.{img_format.lower()}"
+            transformed_img = transform.execute(img)
+            transformed_img.save(img_destination, img_format)
+
+            new_dataset.add_image(img_destination, id=image_id)
+
+            if i % self._PROGRESS_UPDATE_STEP == 0:
+                with self.state:
+                    self.state.export_progress = i / total
+                await self.server.network_completion
+
+        for cat in dataset.cats.values():
+            new_dataset.add_category(**cat)
+
+        for ann in dataset.anns.values():
+            if ann["image_id"] in image_ids:
+                new_dataset.add_annotation(**ann)
+
+        new_dataset.fpath = tmp_dataset_dir / f"{name}.json"
+        new_dataset.reroot()
+        new_dataset.dump()
+
+        # Replace a previous export of the same name with the staged one.
+        recursive_rmdir(dataset_dir)
+        tmp_dataset_dir.rename(dataset_dir)
+
+    def _exporting_dataset(self):
+        """Return True while a previously started export task is not done."""
+        return self._export_task is not None and not self._export_task.done()
+
+    def export_ui(self):
+        """Emit the export form into the currently open UI container."""
+        self.form_ui()
+
+    def form_ui(self):
+        """Create the ``ExportWidget`` bound to this applet's state and handler."""
+        with html.Div(trame_server=self.server):
+            ExportWidget(
+                current_dataset=("current_dataset",),
+                repository_datasets=("repository_datasets",),
+                export_dataset=(self.on_export_clicked, "[$event]"),
+                status=("export_status",),
+                progress=("export_progress",),
+            )
+
+    @property
+    def ui(self):
+        """Standalone layout for this applet, built on first access and cached."""
+        if self._ui is None:
+            with QLayout(self.server) as layout:
+                self._ui = layout
+
+                with quasar.QDrawer(
+                    v_model=("leftDrawerOpen", True),
+                    side="left",
+                    elevated=True,
+                    width="500",
+                ):
+                    self.export_ui()
+
+                with quasar.QPageContainer():
+                    with quasar.QPage():
+                        with html.Div(classes="row", style="min-height: inherit;"):
+                            with html.Div(classes="col q-pa-md"):
+                                pass
+
+        return self._ui
+
+
+def main(server=None, *args, **kwargs):
+    """Run the export applet as a standalone application.
+
+    The ``server`` argument and any positional arguments are accepted but
+    not used: the server always comes from ``get_server()``. Keyword
+    arguments are forwarded to ``server.start``.
+    """
+    server = get_server()
+    server.client_type = "vue3"
+
+    export_app = ExportApp(server)
+    # Reading the ``ui`` property builds the layout as a side effect.
+    export_app.ui
+
+    server.start(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/nrtk_explorer/app/ui/__init__.py b/src/nrtk_explorer/app/ui/__init__.py
index 7aa1a7c..592c4ef 100644
--- a/src/nrtk_explorer/app/ui/__init__.py
+++ b/src/nrtk_explorer/app/ui/__init__.py
@@ -1,6 +1,6 @@
from .layout import NrtkExplorerLayout
from .image_list import ImageList
-from .collapsible_card import CollapsibleCard
+from .collapsible_card import CollapsibleCard, CollapsibleCardUnslotted
def reload(m=None):
@@ -17,4 +17,5 @@ def reload(m=None):
"NrtkExplorerLayout",
"ImageList",
"CollapsibleCard",
+ "CollapsibleCardUnslotted",
]
diff --git a/src/nrtk_explorer/app/ui/collapsible_card.py b/src/nrtk_explorer/app/ui/collapsible_card.py
index 6a2bc53..1e8edef 100644
--- a/src/nrtk_explorer/app/ui/collapsible_card.py
+++ b/src/nrtk_explorer/app/ui/collapsible_card.py
@@ -29,3 +29,28 @@ def __init__(self, name=None, collapsed=False, **kwargs):
with html.Div(v_show=(name, not collapsed)):
self.slot_content = quasar.QCardSection()
self.slot_actions = quasar.QCardActions(align="right")
+
+
+class CollapsibleCardUnslotted(quasar.QCard):
+    """Card with a title row and a collapsible body.
+
+    Exposes ``slot_title`` (title area next to the toggle button) and
+    ``slot_collapse`` (the element shown or hidden by the toggle).
+    """
+
+    def __init__(self, name=None, collapsed=False, **kwargs):
+        """Build the card.
+
+        Args:
+            name: State variable driving the open/closed toggle. When
+                omitted, a name is generated from the counter shared with
+                ``CollapsibleCard`` and marked client-only.
+            collapsed: Passed, negated, alongside ``name`` to ``v_show``.
+            **kwargs: Forwarded to ``quasar.QCard``.
+        """
+        super().__init__(**kwargs)
+
+        if name is None:
+            CollapsibleCard.id_count += 1
+            name = f"is_card_open_{CollapsibleCard.id_count}"
+            self.state.client_only(name)  # keep it local if not provided
+
+        toggle_expression = f"{name} = !{name}"
+        icon_expression = f"{name} ? 'keyboard_arrow_up' : 'keyboard_arrow_down'"
+
+        with self:
+            with quasar.QCardSection():
+                with html.Div(classes="row items-center no-wrap"):
+                    self.slot_title = html.Div(classes="col")
+                    with html.Div(classes="col-auto"):
+                        quasar.QBtn(
+                            round=True,
+                            flat=True,
+                            dense=True,
+                            click=toggle_expression,
+                            icon=(icon_expression,),
+                        )
+            with quasar.QSlideTransition():
+                self.slot_collapse = html.Div(v_show=(name, not collapsed))
diff --git a/src/nrtk_explorer/app/ui/layout.py b/src/nrtk_explorer/app/ui/layout.py
index 3fc3b84..250e605 100644
--- a/src/nrtk_explorer/app/ui/layout.py
+++ b/src/nrtk_explorer/app/ui/layout.py
@@ -1,4 +1,3 @@
-from pathlib import Path
from trame.ui.quasar import QLayout
from trame.widgets import quasar
from trame.widgets import html
@@ -8,13 +7,13 @@
VERTICAL_SPLIT_DEFAULT_VALUE = 40
-def parse_dataset_dirs(datasets):
- return [{"label": Path(ds).name, "value": ds} for ds in datasets]
-
-
class NrtkDrawer(html.Div):
def __init__(
- self, dataset_paths=[], embeddings_app=None, filtering_app=None, transforms_app=None
+ self,
+ embeddings_app=None,
+ filtering_app=None,
+ transforms_app=None,
+ export_app=None,
):
super().__init__(classes="q-pa-md q-gutter-md")
@@ -27,7 +26,7 @@ def __init__(
quasar.QSelect(
label="Dataset",
v_model=("current_dataset",),
- options=(parse_dataset_dirs(dataset_paths),),
+ options=("all_datasets_options",),
filled=True,
emit_value=True,
map_options=True,
@@ -84,6 +83,13 @@ def __init__(
with card.slot_actions:
transforms_app.apply_ui()
+ # Export
+ with ui.CollapsibleCardUnslotted() as card:
+ with card.slot_title:
+ html.Span("Export Dataset", classes="text-h6")
+ with card.slot_collapse:
+ export_app.export_ui()
+
# Filters
with ui.CollapsibleCard() as card:
with card.slot_title:
@@ -126,10 +132,10 @@ def __init__(
self,
server,
reload=None,
- dataset_paths=None,
embeddings_app=None,
filtering_app=None,
transforms_app=None,
+ export_app=None,
**kwargs,
):
super().__init__(server, view="lhh LpR lff", classes="shadow-2 rounded-borders bg-grey-2")
@@ -150,10 +156,10 @@ def __init__(
) as split_drawer_main:
with split_drawer_main.slot_before:
NrtkDrawer(
- dataset_paths=dataset_paths,
embeddings_app=embeddings_app,
filtering_app=filtering_app,
transforms_app=transforms_app,
+ export_app=export_app,
)
with split_drawer_main.slot_after:
with Splitter(
diff --git a/src/nrtk_explorer/library/dataset.py b/src/nrtk_explorer/library/dataset.py
index 72ee843..f19f853 100644
--- a/src/nrtk_explorer/library/dataset.py
+++ b/src/nrtk_explorer/library/dataset.py
@@ -5,7 +5,7 @@
dataset = get_dataset("path/to/dataset.json")
"""
-from typing import Sequence as SequenceType
+from typing import Sequence as SequenceType, Union
from abc import ABC, abstractmethod
import os
from functools import lru_cache
@@ -65,6 +65,25 @@ def get_image(self, id: int):
return JsonDataset(path)
+def discover_datasets(repository: Union[Path, None]) -> list[Path]:
+    """Find the COCO dataset files stored in ``repository``.
+
+    A dataset is a ``.json`` file located directly inside a sub-directory of
+    ``repository`` for which ``is_coco_dataset`` returns a true value.
+
+    Args:
+        repository: Directory to scan, or ``None`` when there is none.
+
+    Returns:
+        The matching files in sorted order; empty when ``repository`` is
+        ``None``.
+    """
+    datasets: list[Path] = []
+
+    if repository is None:
+        return datasets
+
+    # iterdir() yields entries in arbitrary order; sort so the resulting list
+    # does not change between runs or filesystems.
+    for item in sorted(repository.iterdir()):
+        if not item.is_dir():
+            continue
+        for file in sorted(item.iterdir()):
+            if file.is_file() and file.suffix == ".json" and is_coco_dataset(str(file)):
+                datasets.append(file)
+
+    return datasets
+
+
+def dataset_select_options(datasets: list[str]):
+    """Turn dataset identifiers into ``{"label", "value"}`` option entries.
+
+    Args:
+        datasets: Dataset identifiers (paths or names).
+
+    Returns:
+        One dict per identifier, in input order: ``"label"`` is the final
+        path component of the identifier, ``"value"`` the identifier itself.
+    """
+    return [dict(label=Path(identifier).name, value=identifier) for identifier in datasets]
+
+
def is_coco_dataset(path: str):
if not os.path.exists(path):
return False
diff --git a/src/nrtk_explorer/widgets/nrtk_explorer.py b/src/nrtk_explorer/widgets/nrtk_explorer.py
index 86ad7e8..47f8d49 100644
--- a/src/nrtk_explorer/widgets/nrtk_explorer.py
+++ b/src/nrtk_explorer/widgets/nrtk_explorer.py
@@ -89,3 +89,20 @@ def __init__(self, **kwargs):
"update:operator",
"update:invert",
]
+
+
+class ExportWidget(HtmlElement):
+    """Python-side element for the ``export-widget`` client component."""
+
+    def __init__(self, **kwargs):
+        super().__init__("export-widget", **kwargs)
+
+        # Tuples are presumably (Python keyword, client-side name) pairs;
+        # confirm against HtmlElement.
+        exposed_attributes = [
+            ("current_dataset", "currentDataset"),
+            ("repository_datasets", "repositoryDatasets"),
+            "status",
+            "progress",
+        ]
+        exposed_events = [
+            ("export_dataset", "exportDataset"),
+        ]
+
+        self._attr_names += exposed_attributes
+        self._event_names += exposed_events
diff --git a/vue-components/src/components/ExportWidget.vue b/vue-components/src/components/ExportWidget.vue
new file mode 100644
index 0000000..7d9c7b4
--- /dev/null
+++ b/vue-components/src/components/ExportWidget.vue
@@ -0,0 +1,117 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Export
+
+
+
diff --git a/vue-components/src/components/index.js b/vue-components/src/components/index.js
index 444762d..1ce2167 100644
--- a/vue-components/src/components/index.js
+++ b/vue-components/src/components/index.js
@@ -3,11 +3,13 @@ import ParamsWidget from './ParamsWidget.vue'
import TransformsWidget from './TransformsWidget.vue'
import FilterOptionsWidget from './FilterOptionsWidget.vue'
import FilterOperatorWidget from './FilterOperatorWidget.vue'
+import ExportWidget from './ExportWidget.vue'
export default {
scatterPlot: ScatterPlot,
paramsWidget: ParamsWidget,
transformsWidget: TransformsWidget,
filterOptionsWidget: FilterOptionsWidget,
- filterOperatorWidget: FilterOperatorWidget
+ filterOperatorWidget: FilterOperatorWidget,
+ exportWidget: ExportWidget
}