Skip to content

Commit

Permalink
Add ability to export transformed dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
alesgenova committed Dec 13, 2024
1 parent 5428c1c commit 919dcdf
Show file tree
Hide file tree
Showing 10 changed files with 458 additions and 17 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ nrtk-explorer --dataset ./nrtk_explorer_datasets/coco-od-2017/mini_val2017.json
or a directory loadable by the [Dataset](https://huggingface.co/docs/datasets/index) library.
You can specify multiple datasets using a space as the
separator. Example: `nrtk-explorer --dataset ../foo-dir/coco.json cppe-5`
- `--repository` Specify an existing directory where exported datasets will be saved to and loaded from.
- `--download` Cache Hugging Face Hub datasets locally instead of streaming them.
When datasets are streamed, nrtk-explorer limits the number of loaded images.
- `--models` specify the Hugging Face Hub [object detection](https://huggingface.co/models?pipeline_tag=object-detection&library=transformers&sort=trending)
Expand Down
50 changes: 44 additions & 6 deletions src/nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import logging
from typing import Iterable

from pathlib import Path

from trame.widgets import html
from trame_server.utils.namespace import Translator
from nrtk_explorer.library.filtering import FilterProtocol
from nrtk_explorer.library.dataset import get_dataset, expand_hugging_face_datasets
from nrtk_explorer.library.dataset import (
get_dataset,
expand_hugging_face_datasets,
discover_datasets,
dataset_select_options,
)
from nrtk_explorer.library.debounce import debounce

from nrtk_explorer.app.images.images import Images
from nrtk_explorer.app.embeddings import EmbeddingsApp
from nrtk_explorer.app.export import ExportApp
from nrtk_explorer.app.transforms import TransformsApp
from nrtk_explorer.app.filtering import FilteringApp
from nrtk_explorer.app.applet import Applet
Expand All @@ -34,6 +42,14 @@
NUM_IMAGES_DEBOUNCE_TIME = 0.3 # seconds


def dir_path(arg):
    """Convert a command-line value into a resolved path to an existing directory.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        arg: Path-like value as given on the command line.

    Returns:
        The resolved ``pathlib.Path`` for ``arg``.

    Raises:
        NotADirectoryError: If the resolved path is not an existing directory.
            The exception carries the original, unresolved ``arg``.
    """
    resolved = Path(arg).resolve()
    # Guard clause: reject anything that is not a directory on disk.
    if not resolved.is_dir():
        raise NotADirectoryError(arg)
    return resolved


# ---------------------------------------------------------
# Engine class
# ---------------------------------------------------------
Expand All @@ -50,6 +66,13 @@ def __init__(self, server=None):
help="Path to the JSON file describing the image dataset",
)

self.server.cli.add_argument(
"--repository",
default=None,
type=dir_path,
help="Path to the directory where exported datasets will be saved to.",
)

self.server.cli.add_argument(
"--download",
action="store_true",
Expand All @@ -58,11 +81,21 @@ def __init__(self, server=None):
)

known_args, _ = self.server.cli.parse_known_args()
dataset_identifiers = expand_hugging_face_datasets(

self.state.input_datasets = expand_hugging_face_datasets(
known_args.dataset, not known_args.download
)
self.input_paths = dataset_identifiers
self.state.current_dataset = self.input_paths[0]

self.context.repository = known_args.repository
self.state.repository_datasets = [
str(path) for path in discover_datasets(self.context.repository)
]

self.state.all_datasets = self.state.input_datasets + self.state.repository_datasets

self.state.all_datasets_options = dataset_select_options(self.state.all_datasets)

self.state.current_dataset = self.state.all_datasets[0]

images = Images(server=self.server)

Expand All @@ -81,6 +114,10 @@ def __init__(self, server=None):
server=self.server.create_child_server(translator=filtering_translator),
)

self._export_app = ExportApp(
server=self.server.create_child_server(),
)

self._transforms_app.set_on_transform(self._embeddings_app.on_run_transformations)
self._embeddings_app.set_on_hover(self._transforms_app.on_image_hovered)
self._transforms_app.set_on_hover(self._embeddings_app.on_image_hovered)
Expand Down Expand Up @@ -153,7 +190,8 @@ def resample_images(self, **kwargs):
else:
selected_images = images

self.state.dataset_ids = [str(img["id"]) for img in selected_images]
self.context.dataset_ids = [img["id"] for img in selected_images]
self.state.dataset_ids = [str(image_id) for image_id in self.context.dataset_ids]

def _build_ui(self):
extra_args = {}
Expand All @@ -163,9 +201,9 @@ def _build_ui(self):

self.ui = ui.NrtkExplorerLayout(
server=self.server,
dataset_paths=self.input_paths,
embeddings_app=self._embeddings_app,
filtering_app=self._filtering_app,
transforms_app=self._transforms_app,
export_app=self._export_app,
**extra_args,
)
216 changes: 216 additions & 0 deletions src/nrtk_explorer/app/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from pathlib import Path

from trame.app import get_server, asynchronous
from trame.widgets import quasar, html
from trame.ui.quasar import QLayout

from nrtk_explorer.app.applet import Applet
from nrtk_explorer.library.dataset import (
discover_datasets,
dataset_select_options,
)
import nrtk_explorer.library.transforms as trans
from nrtk_explorer.widgets.nrtk_explorer import ExportWidget


def recursive_rmdir(path: Path):
    """Delete the directory ``path`` and everything below it.

    The call is a no-op when ``path`` does not exist, is not a directory, or
    is itself a symbolic link.

    Symbolic links found inside the tree are removed as links and are never
    followed, so data living outside ``path`` is left untouched. Any other
    non-directory entry (regular file, broken link, ...) is unlinked.

    Args:
        path: Directory to remove.

    Raises:
        OSError: If an entry cannot be removed (e.g. insufficient permissions).
    """
    # ``is_dir()`` follows symlinks, so check for a link explicitly first.
    if path.is_symlink() or not path.is_dir():
        return

    for item in path.iterdir():
        if item.is_dir() and not item.is_symlink():
            recursive_rmdir(item)
        else:
            # Files, symlinks (even to directories) and broken links:
            # remove the entry itself without descending into it.
            item.unlink()

    path.rmdir()


class ExportApp(Applet):
    """Applet that writes the (transformed) current dataset into the repository directory.

    State variables used by this applet:

    * ``export_status``: one of ``"idle"``, ``"pending"``, ``"success"``, ``"fail"``.
    * ``export_progress``: number between 0 and 1.
    * ``repository_datasets`` / ``all_datasets`` / ``all_datasets_options``:
      refreshed after every export attempt.

    The export reads ``context.repository``, ``context.transforms``,
    ``context.dataset`` and ``context.dataset_ids``. Apart from a default for
    ``repository``, these are not set by this class (presumably they are
    populated by the other applets -- confirm against the callers).
    """

    # How often (in images) the progress is pushed to the client.
    _PROGRESS_UPDATE_STEP = 30
    # Ensure a directory doesn't have too many files.
    _MAX_FILES_PER_DIRECTORY = 100
    # Sub-directory of the repository used to stage an export before it is
    # moved to its final location; cannot be used as a dataset name.
    _STAGING_DIR_NAME = "tmp"

    def __init__(self, server):
        super().__init__(server)

        self.context.setdefault("repository", "")
        self.state.setdefault("current_dataset", "")
        self.state.setdefault("repository_datasets", [])
        self.state.setdefault("export_status", "idle")
        self.state.setdefault("export_progress", 0)

        self.server.controller.add("on_server_ready")(self.on_server_ready)

        self._ui = None
        # Handle of the running (or last) export task, None before the first export.
        self._export_task = None

    def on_server_ready(self, *args, **kwargs):
        # Bind instance methods to state change
        pass

    def on_export_clicked(self, event):
        """Handle the widget's export event (a mapping with ``name`` and ``full`` keys)."""
        self.start_export(event["name"], event["full"])

    def start_export(self, name, full):
        """Schedule an export task unless one is still running."""
        if self._exporting_dataset():
            return

        self._export_task = asynchronous.create_task(self.export_dataset(name, full))

    async def export_dataset(self, name, full):
        """Run an export and mirror its outcome into the shared state.

        Args:
            name: Name of the dataset directory to create in the repository.
            full: When true export every image of the dataset, otherwise only
                the ids in ``context.dataset_ids``.
        """
        # Imported locally, like ``kwcoco`` below, so the block is self-contained.
        import logging

        with self.state:
            self.state.export_status = "pending"
            self.state.export_progress = 0
        await self.server.network_completion

        try:
            await self._export_dataset(name, full)
            with self.state:
                self.state.export_status = "success"
        except Exception:
            # The client only sees "fail"; keep the cause in the server log.
            logging.getLogger(__name__).exception("Failed to export dataset %r", name)
            with self.state:
                self.state.export_status = "fail"
        finally:
            with self.state:
                self.state.export_progress = 1
                # Update list of available datasets
                self.state.repository_datasets = [
                    str(path) for path in discover_datasets(self.context.repository)
                ]
                self.state.all_datasets = (
                    self.state.input_datasets + self.state.repository_datasets
                )
                self.state.all_datasets_options = dataset_select_options(self.state.all_datasets)
            await self.server.network_completion

    @classmethod
    def _validate_dataset_name(cls, name):
        """Raise ``ValueError`` unless ``name`` is a safe, single directory name.

        ``name`` comes from the client. It is used to build a directory that is
        recursively deleted before being replaced, so it must designate a direct
        child of the repository: not empty, not ``.``/``..``, without any path
        separator, and not the reserved staging directory.
        """
        if not isinstance(name, str) or name in ("", ".", "..", cls._STAGING_DIR_NAME):
            raise ValueError(f"Invalid dataset name: {name!r}")
        # ``Path(name).name`` differs from ``name`` as soon as the value holds
        # more than one path component (separator, drive, root, ...).
        if Path(name).name != name:
            raise ValueError(f"Invalid dataset name: {name!r}")

    def _snapshot_transform(self):
        """Return a chained transform built from copies of the current transforms."""
        # Ensure the transform parameters are frozen for the duration of the task
        transforms = []
        for t in self.context.transforms:
            instance = t["instance"]
            new_instance = instance.__class__()
            new_instance.set_parameters(instance.get_parameters())
            transforms.append(new_instance)

        return trans.ChainedImageTransform(transforms)

    async def _export_dataset(self, name, full):
        """Write the transformed images and a COCO json file for dataset ``name``.

        The dataset is first written to ``<repository>/tmp/<name>`` and only
        moved to ``<repository>/<name>`` (replacing a previous export of the
        same name) once everything has been written.

        Raises:
            ValueError: If ``name`` is not a valid dataset name or no
                repository directory is configured.
        """
        self._validate_dataset_name(name)

        repository = self.context.repository
        if not repository:
            raise ValueError("No repository directory configured for exports")
        repository = Path(repository)

        tmp_dataset_dir = repository / self._STAGING_DIR_NAME / name
        dataset_dir = repository / name
        # Drop leftovers of a previous, interrupted export of the same name.
        recursive_rmdir(tmp_dataset_dir)
        tmp_dataset_dir.mkdir(parents=True)

        transform = self._snapshot_transform()

        dataset = self.context.dataset

        if full:
            image_ids = set(dataset.imgs.keys())
        else:
            image_ids = set(self.context.dataset_ids)

        import kwcoco

        new_dataset = kwcoco.CocoDataset()

        total = len(image_ids)

        for i, image_id in enumerate(image_ids):
            # Images are spread over numbered sub-directories ("0", "1", ...).
            destination_dir = tmp_dataset_dir / str(i // self._MAX_FILES_PER_DIRECTORY)
            destination_dir.mkdir(parents=True, exist_ok=True)

            img = dataset.get_image(image_id)
            img.load()  # Avoid OSError(24, 'Too many open files')
            # transforms require RGB mode
            img = img.convert("RGB") if img.mode != "RGB" else img

            img_format = img.format if img.format is not None else "PNG"

            img_destination = destination_dir / f"{image_id}.{img_format.lower()}"
            transformed_img = transform.execute(img)
            transformed_img.save(img_destination, img_format)

            new_dataset.add_image(img_destination, id=image_id)

            if i % self._PROGRESS_UPDATE_STEP == 0:
                with self.state:
                    self.state.export_progress = i / total
                await self.server.network_completion

        for cat in dataset.cats.values():
            new_dataset.add_category(**cat)

        # Only keep the annotations of exported images.
        for ann in dataset.anns.values():
            if ann["image_id"] in image_ids:
                new_dataset.add_annotation(**ann)

        new_dataset.fpath = tmp_dataset_dir / f"{name}.json"
        new_dataset.reroot()
        new_dataset.dump()

        # Replace a previous export of the same name with the staged one.
        recursive_rmdir(dataset_dir)
        tmp_dataset_dir.rename(dataset_dir)

    def _exporting_dataset(self):
        """Return True while a previously started export task has not finished."""
        return self._export_task is not None and not self._export_task.done()

    def export_ui(self):
        self.form_ui()

    def form_ui(self):
        """Add the export widget, bound to this applet's state, to the current layout."""
        with html.Div(trame_server=self.server):
            ExportWidget(
                current_dataset=("current_dataset",),
                repository_datasets=("repository_datasets",),
                export_dataset=(self.on_export_clicked, "[$event]"),
                status=("export_status",),
                progress=("export_progress",),
            )

    @property
    def ui(self):
        """Standalone layout for this applet, built on first access and then cached."""
        if self._ui is None:
            with QLayout(self.server) as layout:
                self._ui = layout

                with quasar.QDrawer(
                    v_model=("leftDrawerOpen", True),
                    side="left",
                    elevated=True,
                    width="500",
                ):
                    self.export_ui()

                with quasar.QPageContainer():
                    with quasar.QPage():
                        with html.Div(classes="row", style="min-height: inherit;"):
                            with html.Div(classes="col q-pa-md"):
                                pass

        return self._ui


def main(server=None, *args, **kwargs):
    """Run the export applet as a standalone application.

    Args:
        server: Server instance to use, a server name, or ``None`` for the
            default server. Names and ``None`` are resolved with
            ``get_server``; any other value is used as the server itself.
        *args: Ignored.
        **kwargs: Forwarded to ``server.start``.
    """
    # Honour a server passed by the caller instead of always replacing it.
    if server is None or isinstance(server, str):
        server = get_server(server)
    server.client_type = "vue3"

    app = ExportApp(server)
    # Accessing the property builds the layout.
    app.ui

    server.start(**kwargs)


# Start the standalone export application when this module is executed as a script.
if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion src/nrtk_explorer/app/ui/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .layout import NrtkExplorerLayout
from .image_list import ImageList
from .collapsible_card import CollapsibleCard
from .collapsible_card import CollapsibleCard, CollapsibleCardUnslotted


def reload(m=None):
Expand All @@ -17,4 +17,5 @@ def reload(m=None):
"NrtkExplorerLayout",
"ImageList",
"CollapsibleCard",
"CollapsibleCardUnslotted",
]
25 changes: 25 additions & 0 deletions src/nrtk_explorer/app/ui/collapsible_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,28 @@ def __init__(self, name=None, collapsed=False, **kwargs):
with html.Div(v_show=(name, not collapsed)):
self.slot_content = quasar.QCardSection()
self.slot_actions = quasar.QCardActions(align="right")


# QCard with a title row, a round toggle button and a collapsible content area.
# Exposes two containers for callers to fill: ``slot_title`` and ``slot_collapse``.
#
# name: state variable controlling whether the content is shown; when omitted a
#       unique ``is_card_open_<n>`` name is generated from the counter shared
#       with CollapsibleCard (``CollapsibleCard.id_count``).
# collapsed: the content div is created with ``not collapsed`` as the value
#       paired with ``name`` (presumably the variable's initial value -- confirm
#       against the trame widget conventions).
# kwargs: forwarded unchanged to ``quasar.QCard``.
class CollapsibleCardUnslotted(quasar.QCard):
def __init__(self, name=None, collapsed=False, **kwargs):
super().__init__(**kwargs)

# No name supplied: derive a unique state variable name.
if name is None:
CollapsibleCard.id_count += 1
name = f"is_card_open_{CollapsibleCard.id_count}"
self.state.client_only(name) # keep it local if not provided

with self:
# Header: title container on the left, toggle button on the right.
with quasar.QCardSection():
with html.Div(classes="row items-center no-wrap"):
self.slot_title = html.Div(classes="col")
with html.Div(classes="col-auto"):
# The click expression negates the state variable; the icon expression
# selects the up arrow when it is truthy and the down arrow otherwise.
quasar.QBtn(
round=True,
flat=True,
dense=True,
click=f"{name} = !{name}",
icon=(f"{name} ? 'keyboard_arrow_up' : 'keyboard_arrow_down'",),
)
# Content container, shown or hidden through v_show on the state variable.
with quasar.QSlideTransition():
self.slot_collapse = html.Div(v_show=(name, not collapsed))
Loading

0 comments on commit 919dcdf

Please sign in to comment.