Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to export transformed dataset #156

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ nrtk-explorer --dataset ./nrtk_explorer_datasets/coco-od-2017/mini_val2017.json
or a directory loadable by the [Dataset](https://huggingface.co/docs/datasets/index) library.
You can specify multiple datasets using a space as the
separator. Example: `nrtk-explorer --dataset ../foo-dir/coco.json cppe-5`
- `--repository` Specify an existing directory where exported datasets will be saved to and loaded from.
- `--download` Cache Hugging Face Hub datasets locally instead of streaming them.
When datasets are streamed, nrtk-explorer limits the number of loaded images.
- `--models` specify the Hugging Face Hub [object detection](https://huggingface.co/models?pipeline_tag=object-detection&library=transformers&sort=trending)
Expand Down
50 changes: 44 additions & 6 deletions src/nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import logging
from typing import Iterable

from pathlib import Path

from trame.widgets import html
from trame_server.utils.namespace import Translator
from nrtk_explorer.library.filtering import FilterProtocol
from nrtk_explorer.library.dataset import get_dataset, expand_hugging_face_datasets
from nrtk_explorer.library.dataset import (
get_dataset,
expand_hugging_face_datasets,
discover_datasets,
dataset_select_options,
)
from nrtk_explorer.library.debounce import debounce

from nrtk_explorer.app.images.images import Images
from nrtk_explorer.app.embeddings import EmbeddingsApp
from nrtk_explorer.app.export import ExportApp
from nrtk_explorer.app.transforms import TransformsApp
from nrtk_explorer.app.filtering import FilteringApp
from nrtk_explorer.app.applet import Applet
Expand All @@ -34,6 +42,14 @@
NUM_IMAGES_DEBOUNCE_TIME = 0.3 # seconds


def dir_path(arg):
path = Path(arg).resolve()
if path.is_dir():
return path
else:
raise NotADirectoryError(arg)


# ---------------------------------------------------------
# Engine class
# ---------------------------------------------------------
Expand All @@ -50,6 +66,13 @@ def __init__(self, server=None):
help="Path to the JSON file describing the image dataset",
)

self.server.cli.add_argument(
"--repository",
default=None,
type=dir_path,
help="Path to the directory where exported datasets will be saved to.",
)

self.server.cli.add_argument(
"--download",
action="store_true",
Expand All @@ -58,11 +81,21 @@ def __init__(self, server=None):
)

known_args, _ = self.server.cli.parse_known_args()
dataset_identifiers = expand_hugging_face_datasets(

self.state.input_datasets = expand_hugging_face_datasets(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is input_datasets displayed by the GUI?

known_args.dataset, not known_args.download
)
self.input_paths = dataset_identifiers
self.state.current_dataset = self.input_paths[0]

self.context.repository = known_args.repository
self.state.repository_datasets = [
str(path) for path in discover_datasets(self.context.repository)
]

self.state.all_datasets = self.state.input_datasets + self.state.repository_datasets

self.state.all_datasets_options = dataset_select_options(self.state.all_datasets)

self.state.current_dataset = self.state.all_datasets[0]

images = Images(server=self.server)

Expand All @@ -81,6 +114,10 @@ def __init__(self, server=None):
server=self.server.create_child_server(translator=filtering_translator),
)

self._export_app = ExportApp(
server=self.server.create_child_server(),
)

self._transforms_app.set_on_transform(self._embeddings_app.on_run_transformations)
self._embeddings_app.set_on_hover(self._transforms_app.on_image_hovered)
self._transforms_app.set_on_hover(self._embeddings_app.on_image_hovered)
Expand Down Expand Up @@ -153,7 +190,8 @@ def resample_images(self, **kwargs):
else:
selected_images = images

self.state.dataset_ids = [str(img["id"]) for img in selected_images]
self.context.dataset_ids = [img["id"] for img in selected_images]
self.state.dataset_ids = [str(image_id) for image_id in self.context.dataset_ids]

def _build_ui(self):
extra_args = {}
Expand All @@ -163,9 +201,9 @@ def _build_ui(self):

self.ui = ui.NrtkExplorerLayout(
server=self.server,
dataset_paths=self.input_paths,
embeddings_app=self._embeddings_app,
filtering_app=self._filtering_app,
transforms_app=self._transforms_app,
export_app=self._export_app,
**extra_args,
)
216 changes: 216 additions & 0 deletions src/nrtk_explorer/app/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from pathlib import Path

from trame.app import get_server, asynchronous
from trame.widgets import quasar, html
from trame.ui.quasar import QLayout

from nrtk_explorer.app.applet import Applet
from nrtk_explorer.library.dataset import (
discover_datasets,
dataset_select_options,
)
import nrtk_explorer.library.transforms as trans
from nrtk_explorer.widgets.nrtk_explorer import ExportWidget


def recursive_rmdir(path: Path):
if not path.is_dir():
return

for item in path.iterdir():
if item.is_file():
item.unlink()
else:
recursive_rmdir(item)

path.rmdir()


class ExportApp(Applet):
def __init__(self, server):
super().__init__(server)

self.context.setdefault("repository", "")
self.state.setdefault("current_dataset", "")
Comment on lines +33 to +34
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is set in core.py too. Should we refactor so we initialize the shared setup in once place?

self.state.setdefault("repository_datasets", [])
self.state.setdefault("export_status", "idle")
self.state.setdefault("export_progress", 0)

self.server.controller.add("on_server_ready")(self.on_server_ready)

self._ui = None

def on_server_ready(self, *args, **kwargs):
# Bind instance methods to state change
pass
Comment on lines +43 to +45
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reason to keep?


def on_export_clicked(self, event):
self.start_export(event["name"], event["full"])

def start_export(self, name, full):
if self._exporting_dataset():
return
Comment on lines +51 to +52
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

User could have stale images in the export if they change transforms, then click export again?


self._export_task = asynchronous.create_task(self.export_dataset(name, full))

async def export_dataset(self, name, full):
with self.state:
self.state.export_status = "pending"
self.state.export_progress = 0
await self.server.network_completion

try:
await self._export_dataset(name, full)
with self.state:
self.state.export_status = "success"
except Exception:
with self.state:
self.state.export_status = "fail"
finally:
with self.state:
self.state.export_progress = 1
# Update list of available datasets
self.state.repository_datasets = [
str(path) for path in discover_datasets(self.context.repository)
]
self.state.all_datasets = (
self.state.input_datasets + self.state.repository_datasets
)
self.state.all_datasets_options = dataset_select_options(self.state.all_datasets)
await self.server.network_completion

async def _export_dataset(self, name, full):
tmp_dataset_dir = self.context.repository / "tmp" / name
dataset_dir = self.context.repository / name
recursive_rmdir(tmp_dataset_dir)
Path.mkdir(tmp_dataset_dir, parents=True)

# Ensure the transform parameters are frozen for the duration of the task
transforms = []
for t in self.context.transforms:
instance = t["instance"]
new_instance = instance.__class__()
new_instance.set_parameters(instance.get_parameters())
transforms.append(new_instance)

# transforms = list(map(lambda t: t["instance"], self.context.transforms))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️

transform = trans.ChainedImageTransform(transforms)

dataset = self.context.dataset

if full:
image_ids = set(dataset.imgs.keys())
else:
image_ids = set(self.context.dataset_ids)

import kwcoco

new_dataset = kwcoco.CocoDataset()

# How often to update the progress
PROGRESS_UPDATE_STEP = 30
# Ensure a directory doesn't have too many files
MAX_FILES_PER_DIRECTORY = 100

def subdir_generator(max_files):
i = 0

while True:
yield str(i // max_files)
i += 1

subdir = subdir_generator(MAX_FILES_PER_DIRECTORY)

for i, image_id in enumerate(image_ids):
subdir_name = next(subdir)
destination_dir = tmp_dataset_dir / subdir_name

if not Path.exists(destination_dir):
Path.mkdir(destination_dir, parents=True)

img = dataset.get_image(image_id)
img.load() # Avoid OSError(24, 'Too many open files')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

proably not needed in this loop?

# transforms require RGB mode
img = img.convert("RGB") if img.mode != "RGB" else img

if img.format is not None:
img_format = img.format
else:
img_format = "PNG"

img_destination = destination_dir / f"{image_id}.{img_format.lower()}"
transformed_img = transform.execute(img)
transformed_img.save(img_destination, img_format)

new_dataset.add_image(img_destination, id=image_id)

if i % PROGRESS_UPDATE_STEP == 0:
with self.state:
self.state.export_progress = i / len(image_ids)
await self.server.network_completion

for cat in dataset.cats.values():
new_dataset.add_category(**cat)

for ann in dataset.anns.values():
if ann["image_id"] in image_ids:
new_dataset.add_annotation(**ann)

new_dataset.fpath = tmp_dataset_dir / f"{name}.json"
new_dataset.reroot()
new_dataset.dump()

recursive_rmdir(dataset_dir)
Path.rename(tmp_dataset_dir, dataset_dir)

def _exporting_dataset(self):
return hasattr(self, "_export_task") and not self._export_task.done()

def export_ui(self):
self.form_ui()

def form_ui(self):
with html.Div(trame_server=self.server):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is trame_server kwarg needed?

ExportWidget(
current_dataset=("current_dataset",),
repository_datasets=("repository_datasets",),
export_dataset=(self.on_export_clicked, "[$event]"),
status=("export_status",),
progress=("export_progress",),
)

@property
def ui(self):
if self._ui is None:
with QLayout(self.server) as layout:
self._ui = layout

with quasar.QDrawer(
v_model=("leftDrawerOpen", True),
side="left",
elevated=True,
width="500",
):
self.export_ui()

with quasar.QPageContainer():
with quasar.QPage():
with html.Div(classes="row", style="min-height: inherit;"):
with html.Div(classes="col q-pa-md"):
pass

return self._ui


def main(server=None, *args, **kwargs):
server = get_server()
server.client_type = "vue3"

app = ExportApp(server)
app.ui

server.start(**kwargs)


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion src/nrtk_explorer/app/ui/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .layout import NrtkExplorerLayout
from .image_list import ImageList
from .collapsible_card import CollapsibleCard
from .collapsible_card import CollapsibleCard, CollapsibleCardUnslotted


def reload(m=None):
Expand All @@ -17,4 +17,5 @@ def reload(m=None):
"NrtkExplorerLayout",
"ImageList",
"CollapsibleCard",
"CollapsibleCardUnslotted",
]
25 changes: 25 additions & 0 deletions src/nrtk_explorer/app/ui/collapsible_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,28 @@ def __init__(self, name=None, collapsed=False, **kwargs):
with html.Div(v_show=(name, not collapsed)):
self.slot_content = quasar.QCardSection()
self.slot_actions = quasar.QCardActions(align="right")


class CollapsibleCardUnslotted(quasar.QCard):
def __init__(self, name=None, collapsed=False, **kwargs):
super().__init__(**kwargs)

if name is None:
CollapsibleCard.id_count += 1
name = f"is_card_open_{CollapsibleCard.id_count}"
self.state.client_only(name) # keep it local if not provided

with self:
with quasar.QCardSection():
with html.Div(classes="row items-center no-wrap"):
self.slot_title = html.Div(classes="col")
with html.Div(classes="col-auto"):
quasar.QBtn(
round=True,
flat=True,
dense=True,
click=f"{name} = !{name}",
icon=(f"{name} ? 'keyboard_arrow_up' : 'keyboard_arrow_down'",),
)
with quasar.QSlideTransition():
self.slot_collapse = html.Div(v_show=(name, not collapsed))
Loading
Loading