-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ability to export transformed dataset #156
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
from pathlib import Path | ||
|
||
from trame.app import get_server, asynchronous | ||
from trame.widgets import quasar, html | ||
from trame.ui.quasar import QLayout | ||
|
||
from nrtk_explorer.app.applet import Applet | ||
from nrtk_explorer.library.dataset import ( | ||
discover_datasets, | ||
dataset_select_options, | ||
) | ||
import nrtk_explorer.library.transforms as trans | ||
from nrtk_explorer.widgets.nrtk_explorer import ExportWidget | ||
|
||
|
||
def recursive_rmdir(path: Path): | ||
if not path.is_dir(): | ||
return | ||
|
||
for item in path.iterdir(): | ||
if item.is_file(): | ||
item.unlink() | ||
else: | ||
recursive_rmdir(item) | ||
|
||
path.rmdir() | ||
|
||
|
||
class ExportApp(Applet): | ||
def __init__(self, server): | ||
super().__init__(server) | ||
|
||
self.context.setdefault("repository", "") | ||
self.state.setdefault("current_dataset", "") | ||
Comment on lines
+33
to
+34
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is set in core.py too. Should we refactor so we initialize the shared setup in once place? |
||
self.state.setdefault("repository_datasets", []) | ||
self.state.setdefault("export_status", "idle") | ||
self.state.setdefault("export_progress", 0) | ||
|
||
self.server.controller.add("on_server_ready")(self.on_server_ready) | ||
|
||
self._ui = None | ||
|
||
def on_server_ready(self, *args, **kwargs): | ||
# Bind instance methods to state change | ||
pass | ||
Comment on lines
+43
to
+45
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reason to keep? |
||
|
||
def on_export_clicked(self, event): | ||
self.start_export(event["name"], event["full"]) | ||
|
||
def start_export(self, name, full): | ||
if self._exporting_dataset(): | ||
return | ||
Comment on lines
+51
to
+52
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. User could have stale images in the export if they change transforms, then click export again? |
||
|
||
self._export_task = asynchronous.create_task(self.export_dataset(name, full)) | ||
|
||
async def export_dataset(self, name, full): | ||
with self.state: | ||
self.state.export_status = "pending" | ||
self.state.export_progress = 0 | ||
await self.server.network_completion | ||
|
||
try: | ||
await self._export_dataset(name, full) | ||
with self.state: | ||
self.state.export_status = "success" | ||
except Exception: | ||
with self.state: | ||
self.state.export_status = "fail" | ||
finally: | ||
with self.state: | ||
self.state.export_progress = 1 | ||
# Update list of available datasets | ||
self.state.repository_datasets = [ | ||
str(path) for path in discover_datasets(self.context.repository) | ||
] | ||
self.state.all_datasets = ( | ||
self.state.input_datasets + self.state.repository_datasets | ||
) | ||
self.state.all_datasets_options = dataset_select_options(self.state.all_datasets) | ||
await self.server.network_completion | ||
|
||
async def _export_dataset(self, name, full): | ||
tmp_dataset_dir = self.context.repository / "tmp" / name | ||
dataset_dir = self.context.repository / name | ||
recursive_rmdir(tmp_dataset_dir) | ||
Path.mkdir(tmp_dataset_dir, parents=True) | ||
|
||
# Ensure the transform parameters are frozen for the duration of the task | ||
transforms = [] | ||
for t in self.context.transforms: | ||
instance = t["instance"] | ||
new_instance = instance.__class__() | ||
new_instance.set_parameters(instance.get_parameters()) | ||
transforms.append(new_instance) | ||
|
||
# transforms = list(map(lambda t: t["instance"], self.context.transforms)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🕵️ |
||
transform = trans.ChainedImageTransform(transforms) | ||
|
||
dataset = self.context.dataset | ||
|
||
if full: | ||
image_ids = set(dataset.imgs.keys()) | ||
else: | ||
image_ids = set(self.context.dataset_ids) | ||
|
||
import kwcoco | ||
|
||
new_dataset = kwcoco.CocoDataset() | ||
|
||
# How often to update the progress | ||
PROGRESS_UPDATE_STEP = 30 | ||
# Ensure a directory doesn't have too many files | ||
MAX_FILES_PER_DIRECTORY = 100 | ||
|
||
def subdir_generator(max_files): | ||
i = 0 | ||
|
||
while True: | ||
yield str(i // max_files) | ||
i += 1 | ||
|
||
subdir = subdir_generator(MAX_FILES_PER_DIRECTORY) | ||
|
||
for i, image_id in enumerate(image_ids): | ||
subdir_name = next(subdir) | ||
destination_dir = tmp_dataset_dir / subdir_name | ||
|
||
if not Path.exists(destination_dir): | ||
Path.mkdir(destination_dir, parents=True) | ||
|
||
img = dataset.get_image(image_id) | ||
img.load() # Avoid OSError(24, 'Too many open files') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. proably not needed in this loop? |
||
# transforms require RGB mode | ||
img = img.convert("RGB") if img.mode != "RGB" else img | ||
|
||
if img.format is not None: | ||
img_format = img.format | ||
else: | ||
img_format = "PNG" | ||
|
||
img_destination = destination_dir / f"{image_id}.{img_format.lower()}" | ||
transformed_img = transform.execute(img) | ||
transformed_img.save(img_destination, img_format) | ||
|
||
new_dataset.add_image(img_destination, id=image_id) | ||
|
||
if i % PROGRESS_UPDATE_STEP == 0: | ||
with self.state: | ||
self.state.export_progress = i / len(image_ids) | ||
await self.server.network_completion | ||
|
||
for cat in dataset.cats.values(): | ||
new_dataset.add_category(**cat) | ||
|
||
for ann in dataset.anns.values(): | ||
if ann["image_id"] in image_ids: | ||
new_dataset.add_annotation(**ann) | ||
|
||
new_dataset.fpath = tmp_dataset_dir / f"{name}.json" | ||
new_dataset.reroot() | ||
new_dataset.dump() | ||
|
||
recursive_rmdir(dataset_dir) | ||
Path.rename(tmp_dataset_dir, dataset_dir) | ||
|
||
def _exporting_dataset(self): | ||
return hasattr(self, "_export_task") and not self._export_task.done() | ||
|
||
def export_ui(self): | ||
self.form_ui() | ||
|
||
def form_ui(self): | ||
with html.Div(trame_server=self.server): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is |
||
ExportWidget( | ||
current_dataset=("current_dataset",), | ||
repository_datasets=("repository_datasets",), | ||
export_dataset=(self.on_export_clicked, "[$event]"), | ||
status=("export_status",), | ||
progress=("export_progress",), | ||
) | ||
|
||
@property | ||
def ui(self): | ||
if self._ui is None: | ||
with QLayout(self.server) as layout: | ||
self._ui = layout | ||
|
||
with quasar.QDrawer( | ||
v_model=("leftDrawerOpen", True), | ||
side="left", | ||
elevated=True, | ||
width="500", | ||
): | ||
self.export_ui() | ||
|
||
with quasar.QPageContainer(): | ||
with quasar.QPage(): | ||
with html.Div(classes="row", style="min-height: inherit;"): | ||
with html.Div(classes="col q-pa-md"): | ||
pass | ||
|
||
return self._ui | ||
|
||
|
||
def main(server=None, *args, **kwargs): | ||
server = get_server() | ||
server.client_type = "vue3" | ||
|
||
app = ExportApp(server) | ||
app.ui | ||
|
||
server.start(**kwargs) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is
input_datasets
displayed by the GUI?