diff --git a/README.md b/README.md index 91088fa..e10a698 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ nrtk-explorer --dataset ./nrtk_explorer_datasets/coco-od-2017/mini_val2017.json or a directory loadable by the [Dataset](https://huggingface.co/docs/datasets/index) library. You can specify multiple datasets using a space as the separator. Example: `nrtk-explorer --dataset ../foo-dir/coco.json cppe-5` +- `--repository` Specify an existing directory where exported datasets will be saved to and loaded from. - `--download` Cache Hugging Face Hub datasets locally instead of streaming them. When datasets are streamed, nrtk-explorer limits the number of loaded images. - `--models` specify the Hugging Face Hub [object detection](https://huggingface.co/models?pipeline_tag=object-detection&library=transformers&sort=trending) diff --git a/src/nrtk_explorer/app/core.py b/src/nrtk_explorer/app/core.py index 8f13b3c..49bd649 100644 --- a/src/nrtk_explorer/app/core.py +++ b/src/nrtk_explorer/app/core.py @@ -1,14 +1,22 @@ import logging from typing import Iterable +from pathlib import Path + from trame.widgets import html from trame_server.utils.namespace import Translator from nrtk_explorer.library.filtering import FilterProtocol -from nrtk_explorer.library.dataset import get_dataset, expand_hugging_face_datasets +from nrtk_explorer.library.dataset import ( + get_dataset, + expand_hugging_face_datasets, + discover_datasets, + dataset_select_options, +) from nrtk_explorer.library.debounce import debounce from nrtk_explorer.app.images.images import Images from nrtk_explorer.app.embeddings import EmbeddingsApp +from nrtk_explorer.app.export import ExportApp from nrtk_explorer.app.transforms import TransformsApp from nrtk_explorer.app.filtering import FilteringApp from nrtk_explorer.app.applet import Applet @@ -34,6 +42,14 @@ NUM_IMAGES_DEBOUNCE_TIME = 0.3 # seconds +def dir_path(arg): + path = Path(arg).resolve() + if path.is_dir(): + return path + else: + raise NotADirectoryError(arg) + + # --------------------------------------------------------- # Engine class # --------------------------------------------------------- @@ -50,6 +66,13 @@ def __init__(self, server=None): help="Path to the JSON file describing the image dataset", ) + self.server.cli.add_argument( + "--repository", + default=None, + type=dir_path, + help="Path to the directory where exported datasets will be saved to.", + ) + self.server.cli.add_argument( "--download", action="store_true", @@ -58,11 +81,21 @@ def __init__(self, server=None): ) known_args, _ = self.server.cli.parse_known_args() - dataset_identifiers = expand_hugging_face_datasets( + + self.state.input_datasets = expand_hugging_face_datasets( known_args.dataset, not known_args.download ) - self.input_paths = dataset_identifiers - self.state.current_dataset = self.input_paths[0] + + self.context.repository = known_args.repository + self.state.repository_datasets = [ + str(path) for path in discover_datasets(self.context.repository) + ] + + self.state.all_datasets = self.state.input_datasets + self.state.repository_datasets + + self.state.all_datasets_options = dataset_select_options(self.state.all_datasets) + + self.state.current_dataset = self.state.all_datasets[0] images = Images(server=self.server) @@ -81,6 +114,10 @@ def __init__(self, server=None): server=self.server.create_child_server(translator=filtering_translator), ) + self._export_app = ExportApp( + server=self.server.create_child_server(), + ) + self._transforms_app.set_on_transform(self._embeddings_app.on_run_transformations) self._embeddings_app.set_on_hover(self._transforms_app.on_image_hovered) self._transforms_app.set_on_hover(self._embeddings_app.on_image_hovered) @@ -153,7 +190,8 @@ def resample_images(self, **kwargs): else: selected_images = images - self.state.dataset_ids = [str(img["id"]) for img in selected_images] + self.context.dataset_ids = [img["id"] for img in selected_images] + self.state.dataset_ids = [str(image_id) for image_id in self.context.dataset_ids] def _build_ui(self): extra_args = {} @@ -163,9 +201,9 @@ def _build_ui(self): self.ui = ui.NrtkExplorerLayout( server=self.server, - dataset_paths=self.input_paths, embeddings_app=self._embeddings_app, filtering_app=self._filtering_app, transforms_app=self._transforms_app, + export_app=self._export_app, **extra_args, ) diff --git a/src/nrtk_explorer/app/export.py b/src/nrtk_explorer/app/export.py new file mode 100644 index 0000000..276df47 --- /dev/null +++ b/src/nrtk_explorer/app/export.py @@ -0,0 +1,216 @@ +from pathlib import Path + +from trame.app import get_server, asynchronous +from trame.widgets import quasar, html +from trame.ui.quasar import QLayout + +from nrtk_explorer.app.applet import Applet +from nrtk_explorer.library.dataset import ( + discover_datasets, + dataset_select_options, +) +import nrtk_explorer.library.transforms as trans +from nrtk_explorer.widgets.nrtk_explorer import ExportWidget + + +def recursive_rmdir(path: Path): + if not path.is_dir(): + return + + for item in path.iterdir(): + if item.is_file(): + item.unlink() + else: + recursive_rmdir(item) + + path.rmdir() + + +class ExportApp(Applet): + def __init__(self, server): + super().__init__(server) + + self.context.setdefault("repository", "") + self.state.setdefault("current_dataset", "") + self.state.setdefault("repository_datasets", []) + self.state.setdefault("export_status", "idle") + self.state.setdefault("export_progress", 0) + + self.server.controller.add("on_server_ready")(self.on_server_ready) + + self._ui = None + + def on_server_ready(self, *args, **kwargs): + # Bind instance methods to state change + pass + + def on_export_clicked(self, event): + self.start_export(event["name"], event["full"]) + + def start_export(self, name, full): + if self._exporting_dataset(): + return + + self._export_task = asynchronous.create_task(self.export_dataset(name, full)) + + async def export_dataset(self, name, full): + with self.state: + self.state.export_status = "pending" + self.state.export_progress = 0 + await self.server.network_completion + + try: + await self._export_dataset(name, full) + with self.state: + self.state.export_status = "success" + except Exception: + with self.state: + self.state.export_status = "fail" + finally: + with self.state: + self.state.export_progress = 1 + # Update list of available datasets + self.state.repository_datasets = [ + str(path) for path in discover_datasets(self.context.repository) + ] + self.state.all_datasets = ( + self.state.input_datasets + self.state.repository_datasets + ) + self.state.all_datasets_options = dataset_select_options(self.state.all_datasets) + await self.server.network_completion + + async def _export_dataset(self, name, full): + tmp_dataset_dir = self.context.repository / "tmp" / name + dataset_dir = self.context.repository / name + recursive_rmdir(tmp_dataset_dir) + Path.mkdir(tmp_dataset_dir, parents=True) + + # Ensure the transform parameters are frozen for the duration of the task + transforms = [] + for t in self.context.transforms: + instance = t["instance"] + new_instance = instance.__class__() + new_instance.set_parameters(instance.get_parameters()) + transforms.append(new_instance) + + # transforms = list(map(lambda t: t["instance"], self.context.transforms)) + transform = trans.ChainedImageTransform(transforms) + + dataset = self.context.dataset + + if full: + image_ids = set(dataset.imgs.keys()) + else: + image_ids = set(self.context.dataset_ids) + + import kwcoco + + new_dataset = kwcoco.CocoDataset() + + # How often to update the progress + PROGRESS_UPDATE_STEP = 30 + # Ensure a directory doesn't have too many files + MAX_FILES_PER_DIRECTORY = 100 + + def subdir_generator(max_files): + i = 0 + + while True: + yield str(i // max_files) + i += 1 + + subdir = subdir_generator(MAX_FILES_PER_DIRECTORY) + + for i, image_id in enumerate(image_ids): + subdir_name = next(subdir) + destination_dir = tmp_dataset_dir / subdir_name + + if not Path.exists(destination_dir): + Path.mkdir(destination_dir, parents=True) + + img = dataset.get_image(image_id) + img.load() # Avoid OSError(24, 'Too many open files') + # transforms require RGB mode + img = img.convert("RGB") if img.mode != "RGB" else img + + if img.format is not None: + img_format = img.format + else: + img_format = "PNG" + + img_destination = destination_dir / f"{image_id}.{img_format.lower()}" + transformed_img = transform.execute(img) + transformed_img.save(img_destination, img_format) + + new_dataset.add_image(img_destination, id=image_id) + + if i % PROGRESS_UPDATE_STEP == 0: + with self.state: + self.state.export_progress = i / len(image_ids) + await self.server.network_completion + + for cat in dataset.cats.values(): + new_dataset.add_category(**cat) + + for ann in dataset.anns.values(): + if ann["image_id"] in image_ids: + new_dataset.add_annotation(**ann) + + new_dataset.fpath = tmp_dataset_dir / f"{name}.json" + new_dataset.reroot() + new_dataset.dump() + + recursive_rmdir(dataset_dir) + Path.rename(tmp_dataset_dir, dataset_dir) + + def _exporting_dataset(self): + return hasattr(self, "_export_task") and not self._export_task.done() + + def export_ui(self): + self.form_ui() + + def form_ui(self): + with html.Div(trame_server=self.server): + ExportWidget( + current_dataset=("current_dataset",), + repository_datasets=("repository_datasets",), + export_dataset=(self.on_export_clicked, "[$event]"), + status=("export_status",), + progress=("export_progress",), + ) + + @property + def ui(self): + if self._ui is None: + with QLayout(self.server) as layout: + self._ui = layout + + with quasar.QDrawer( + v_model=("leftDrawerOpen", True), + side="left", + elevated=True, + width="500", + ): + self.export_ui() + + with quasar.QPageContainer(): + with quasar.QPage(): + with html.Div(classes="row", style="min-height: inherit;"): + with html.Div(classes="col q-pa-md"): + pass + + return self._ui + + +def main(server=None, *args, **kwargs): + server = get_server() + server.client_type = "vue3" + + app = ExportApp(server) + app.ui + + server.start(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/src/nrtk_explorer/app/ui/__init__.py b/src/nrtk_explorer/app/ui/__init__.py index 7aa1a7c..592c4ef 100644 --- a/src/nrtk_explorer/app/ui/__init__.py +++ b/src/nrtk_explorer/app/ui/__init__.py @@ -1,6 +1,6 @@ from .layout import NrtkExplorerLayout from .image_list import ImageList -from .collapsible_card import CollapsibleCard +from .collapsible_card import CollapsibleCard, CollapsibleCardUnslotted def reload(m=None): @@ -17,4 +17,5 @@ def reload(m=None): "NrtkExplorerLayout", "ImageList", "CollapsibleCard", + "CollapsibleCardUnslotted", ] diff --git a/src/nrtk_explorer/app/ui/collapsible_card.py b/src/nrtk_explorer/app/ui/collapsible_card.py index 6a2bc53..1e8edef 100644 --- a/src/nrtk_explorer/app/ui/collapsible_card.py +++ b/src/nrtk_explorer/app/ui/collapsible_card.py @@ -29,3 +29,28 @@ def __init__(self, name=None, collapsed=False, **kwargs): with html.Div(v_show=(name, not collapsed)): self.slot_content = quasar.QCardSection() self.slot_actions = quasar.QCardActions(align="right") + + +class CollapsibleCardUnslotted(quasar.QCard): + def __init__(self, name=None, collapsed=False, **kwargs): + super().__init__(**kwargs) + + if name is None: + CollapsibleCard.id_count += 1 + name = f"is_card_open_{CollapsibleCard.id_count}" + self.state.client_only(name) # keep it local if not provided + + with self: + with quasar.QCardSection(): + with html.Div(classes="row items-center no-wrap"): + self.slot_title = html.Div(classes="col") + with html.Div(classes="col-auto"): + quasar.QBtn( + round=True, + flat=True, + dense=True, + click=f"{name} = !{name}", + icon=(f"{name} ? 'keyboard_arrow_up' : 'keyboard_arrow_down'",), + ) + with quasar.QSlideTransition(): + self.slot_collapse = html.Div(v_show=(name, not collapsed)) diff --git a/src/nrtk_explorer/app/ui/layout.py b/src/nrtk_explorer/app/ui/layout.py index 3fc3b84..250e605 100644 --- a/src/nrtk_explorer/app/ui/layout.py +++ b/src/nrtk_explorer/app/ui/layout.py @@ -1,4 +1,3 @@ -from pathlib import Path from trame.ui.quasar import QLayout from trame.widgets import quasar from trame.widgets import html @@ -8,13 +7,13 @@ VERTICAL_SPLIT_DEFAULT_VALUE = 40 -def parse_dataset_dirs(datasets): - return [{"label": Path(ds).name, "value": ds} for ds in datasets] - - class NrtkDrawer(html.Div): def __init__( - self, dataset_paths=[], embeddings_app=None, filtering_app=None, transforms_app=None + self, + embeddings_app=None, + filtering_app=None, + transforms_app=None, + export_app=None, ): super().__init__(classes="q-pa-md q-gutter-md") @@ -27,7 +26,7 @@ def __init__( quasar.QSelect( label="Dataset", v_model=("current_dataset",), - options=(parse_dataset_dirs(dataset_paths),), + options=("all_datasets_options",), filled=True, emit_value=True, map_options=True, @@ -84,6 +83,13 @@ def __init__( with card.slot_actions: transforms_app.apply_ui() + # Export + with ui.CollapsibleCardUnslotted() as card: + with card.slot_title: + html.Span("Export Dataset", classes="text-h6") + with card.slot_collapse: + export_app.export_ui() + # Filters with ui.CollapsibleCard() as card: with card.slot_title: @@ -126,10 +132,10 @@ def __init__( self, server, reload=None, - dataset_paths=None, embeddings_app=None, filtering_app=None, transforms_app=None, + export_app=None, **kwargs, ): super().__init__(server, view="lhh LpR lff", classes="shadow-2 rounded-borders bg-grey-2") @@ -150,10 +156,10 @@ def __init__( ) as split_drawer_main: with split_drawer_main.slot_before: NrtkDrawer( - dataset_paths=dataset_paths, embeddings_app=embeddings_app, filtering_app=filtering_app, transforms_app=transforms_app, + export_app=export_app, ) with split_drawer_main.slot_after: with Splitter( diff --git a/src/nrtk_explorer/library/dataset.py b/src/nrtk_explorer/library/dataset.py index 72ee843..f19f853 100644 --- a/src/nrtk_explorer/library/dataset.py +++ b/src/nrtk_explorer/library/dataset.py @@ -5,7 +5,7 @@ dataset = get_dataset("path/to/dataset.json") """ -from typing import Sequence as SequenceType +from typing import Sequence as SequenceType, Union from abc import ABC, abstractmethod import os from functools import lru_cache @@ -65,6 +65,25 @@ def get_image(self, id: int): return JsonDataset(path) +def discover_datasets(repository: Union[Path, None]) -> list[Path]: + datasets: list[Path] = [] + + if repository is None: + return datasets + + for item in repository.iterdir(): + if item.is_dir(): + for file in item.iterdir(): + if file.is_file() and file.suffix == ".json" and is_coco_dataset(str(file)): + datasets.append(file) + + return datasets + + +def dataset_select_options(datasets: list[str]): + return [{"label": Path(ds).name, "value": ds} for ds in datasets] + + def is_coco_dataset(path: str): if not os.path.exists(path): return False diff --git a/src/nrtk_explorer/widgets/nrtk_explorer.py b/src/nrtk_explorer/widgets/nrtk_explorer.py index 86ad7e8..47f8d49 100644 --- a/src/nrtk_explorer/widgets/nrtk_explorer.py +++ b/src/nrtk_explorer/widgets/nrtk_explorer.py @@ -89,3 +89,20 @@ def __init__(self, **kwargs): "update:operator", "update:invert", ] + + +class ExportWidget(HtmlElement): + def __init__(self, **kwargs): + super().__init__( + "export-widget", + **kwargs, + ) + self._attr_names += [ + ("current_dataset", "currentDataset"), + ("repository_datasets", "repositoryDatasets"), + "status", + "progress", + ] + self._event_names += [ + ("export_dataset", "exportDataset"), + ] diff --git a/vue-components/src/components/ExportWidget.vue b/vue-components/src/components/ExportWidget.vue new file mode 100644 index 0000000..7d9c7b4 --- /dev/null +++ b/vue-components/src/components/ExportWidget.vue @@ -0,0 +1,117 @@ + + + diff --git a/vue-components/src/components/index.js b/vue-components/src/components/index.js index 444762d..1ce2167 100644 --- a/vue-components/src/components/index.js +++ b/vue-components/src/components/index.js @@ -3,11 +3,13 @@ import ParamsWidget from './ParamsWidget.vue' import TransformsWidget from './TransformsWidget.vue' import FilterOptionsWidget from './FilterOptionsWidget.vue' import FilterOperatorWidget from './FilterOperatorWidget.vue' +import ExportWidget from './ExportWidget.vue' export default { scatterPlot: ScatterPlot, paramsWidget: ParamsWidget, transformsWidget: TransformsWidget, filterOptionsWidget: FilterOptionsWidget, - filterOperatorWidget: FilterOperatorWidget + filterOperatorWidget: FilterOperatorWidget, + exportWidget: ExportWidget }