Skip to content

Commit

Permalink
feat(app): make it possible to enable/disable various features of the…
Browse files Browse the repository at this point in the history
… app
  • Loading branch information
alesgenova committed Mar 7, 2025
1 parent e098020 commit a579912
Show file tree
Hide file tree
Showing 15 changed files with 834 additions and 401 deletions.
4 changes: 4 additions & 0 deletions src/nrtk_explorer/app/applet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ def ctrl(self) -> Controller:
@property
def context(self) -> State:
return self.server.context

@property
def ctx(self) -> State:
return self.server.context
190 changes: 145 additions & 45 deletions src/nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from pathlib import Path

from trame.widgets import html
from trame_server.utils.namespace import Translator
from nrtk_explorer.library.filtering import FilterProtocol
from nrtk_explorer.library.dataset import (
Expand All @@ -16,12 +15,16 @@
from nrtk_explorer.library.app_config import process_config


from nrtk_explorer.app.features import (
EnabledFeatures,
DEFAULT_FEATURES,
validate_feature_name,
validate_preset_name,
config_features_to_enabled_features,
config_preset_to_enabled_features,
)
from nrtk_explorer.app.images.images import Images
from nrtk_explorer.app.images.image_server import ImageServer
from nrtk_explorer.app.embeddings import EmbeddingsApp
from nrtk_explorer.app.export import ExportApp
from nrtk_explorer.app.transforms import TransformsApp
from nrtk_explorer.app.filtering import FilteringApp
from nrtk_explorer.app.applet import Applet
from nrtk_explorer.app import ui
import nrtk_explorer.test_data
Expand All @@ -33,10 +36,6 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

html.Template.slot_names.add("before")
html.Template.slot_names.add("after")


DIR_NAME = os.path.dirname(nrtk_explorer.test_data.__file__)
DEFAULT_DATASETS = [
f"{DIR_NAME}/coco-od-2017/test_val2017.json",
Expand Down Expand Up @@ -83,6 +82,25 @@ def dir_path(arg):
"help": "Path to the directory where exported datasets will be saved to",
},
},
"features": {
"flags": ["--features"],
"params": {
"nargs": "+",
"default": None,
"required": False,
"type": validate_feature_name,
"help": "Space separated list of app features to enable.",
}
},
"preset": {
"flags": ["--preset"],
"params": {
"default": None,
"required": False,
"type": validate_preset_name,
"help": "Choose which application features to enable based on a preset name.",
}
}
}


Expand All @@ -101,36 +119,78 @@ def __init__(self, server=None, **kwargs):
str(path) for path in discover_datasets(self.context.repository)
]

if config['preset'] is not None:
self.ctx.enabled_features = config_preset_to_enabled_features(config['preset'])
else:
self.ctx.enabled_features = config_features_to_enabled_features(config['features'])

self.ctx.enabled_features

self.state.all_datasets = self.state.input_datasets + self.state.repository_datasets
self.state.all_datasets_options = dataset_select_options(self.state.all_datasets)
self.state.current_dataset = self.state.all_datasets[0]

images = Images(server=self.server)
self._image_server = ImageServer(server=self.server, images=images)

self._transforms_app = TransformsApp(
server=self.server.create_child_server(), images=images, **kwargs
)

self._embeddings_app = EmbeddingsApp(
server=self.server.create_child_server(),
images=images,
)

filtering_translator = Translator()
filtering_translator.add_translation("categories", "annotation_categories")
self._filtering_app = FilteringApp(
server=self.server.create_child_server(translator=filtering_translator),
)

self._export_app = ExportApp(
server=self.server.create_child_server(),
)

self._transforms_app.set_on_transform(self._embeddings_app.on_run_transformations)
self._embeddings_app.set_on_hover(self._transforms_app.on_image_hovered)
self._transforms_app.set_on_hover(self._embeddings_app.on_image_hovered)
self._filtering_app.set_on_apply_filter(self.on_filter_apply)
self._datasets_app = None
if self.datasets_enabled:
from nrtk_explorer.app.features.datasets import DatasetsApp
self._datasets_app = DatasetsApp(
server=self.server.create_child_server(), **kwargs
)
else:
# If datasets selection is disabled, we don't have a way to tweak the sampling
# the images in a dataset. Hence include all images
global NUM_IMAGES_DEFAULT
NUM_IMAGES_DEFAULT = float('inf')

self._transforms_app = None
self.state.transform_enabled = False
if self.transforms_enabled:
from nrtk_explorer.app.features.transforms import TransformsApp
self._transforms_app = TransformsApp(
server=self.server.create_child_server(), images=images, **kwargs
)

self._images_app = None
if self.images_enabled:
from nrtk_explorer.app.features.images import ImagesApp
self._images_app = ImagesApp(
server=self.server.create_child_server(), images=images, **kwargs
)

self._inference_app = None
if self.inference_enabled:
from nrtk_explorer.app.features.inference import InferenceApp
self._inference_app = InferenceApp(
server=self.server.create_child_server(), **kwargs
)

self._embeddings_app = None
if self.embeddings_enabled:
from nrtk_explorer.app.features.embeddings import EmbeddingsApp
self._embeddings_app = EmbeddingsApp(
server=self.server.create_child_server(),
images=images,
)

self._filtering_app = None
if self.filtering_enabled:
from nrtk_explorer.app.features.filtering import FilteringApp
filtering_translator = Translator()
filtering_translator.add_translation("categories", "annotation_categories")
self._filtering_app = FilteringApp(
server=self.server.create_child_server(translator=filtering_translator),
)
self.ctrl.apply_filter.add(self.on_filter_apply)

self._export_app = None
if self.export_enabled and self.context.repository is not None:
from nrtk_explorer.app.features.export import ExportApp
self._export_app = ExportApp(
server=self.server.create_child_server(),
)

# Bind instance methods to controller
self.ctrl.on_server_reload = self._build_ui
Expand Down Expand Up @@ -179,19 +239,56 @@ def handle_exceptions(self, e: Exception):
persistent=True,
)

@property
def enabled_features(self) -> EnabledFeatures:
enabled_features = self.ctx.enabled_features
if enabled_features is None:
return DEFAULT_FEATURES
else:
return enabled_features

@property
def datasets_enabled(self) -> bool:
return self.enabled_features.get("datasets", DEFAULT_FEATURES["datasets"])

@property
def images_enabled(self) -> bool:
return self.enabled_features.get("images", DEFAULT_FEATURES["images"])

@property
def embeddings_enabled(self) -> bool:
return self.enabled_features.get("embeddings", DEFAULT_FEATURES["embeddings"])

@property
def transforms_enabled(self) -> bool:
return self.enabled_features.get("transforms", DEFAULT_FEATURES["transforms"])

@property
def filtering_enabled(self) -> bool:
return self.enabled_features.get("filtering", DEFAULT_FEATURES["filtering"])

@property
def export_enabled(self) -> bool:
return self.enabled_features.get("export", DEFAULT_FEATURES["export"])

@property
def inference_enabled(self) -> bool:
return self.enabled_features.get("inference", DEFAULT_FEATURES["inference"])

def on_dataset_change(self, **kwargs):
self.state.dataset_ids = [] # sampled images
self.state.user_selected_ids = [] # ensure image update in transforms app via image list
self.context.dataset = get_dataset(self.state.current_dataset)
self.state.num_images_max = len(self.context.dataset.imgs)
self.state.num_images = min(self.state.num_images_max, self.state.num_images)
self.state.dirty("num_images") # Trigger resample_images()
self.state.random_sampling_disabled = False
self.state.num_images_disabled = False

self.state.annotation_categories = {
category["id"]: category for category in self.context.dataset.cats.values()
}
with self.state:
self.state.dataset_ids = [] # sampled images
self.state.user_selected_ids = [] # ensure image update in transforms app via image list
self.context.dataset = get_dataset(self.state.current_dataset)
self.state.num_images_max = len(self.context.dataset.imgs)
self.state.num_images = min(self.state.num_images_max, self.state.num_images)
self.state.dirty("num_images") # Trigger resample_images()
self.state.random_sampling_disabled = False
self.state.num_images_disabled = False

self.state.annotation_categories = {
category["id"]: category for category in self.context.dataset.cats.values()
}

def on_filter_apply(self, filter: FilterProtocol[Iterable[int]], **kwargs):
selected_ids = []
Expand All @@ -205,7 +302,7 @@ def on_filter_apply(self, filter: FilterProtocol[Iterable[int]], **kwargs):
if include:
selected_ids.append(dataset_id)

self._embeddings_app.on_select(selected_ids)
self.state.user_selected_ids = selected_ids

def resample_images(self, **kwargs):
ids = [image["id"] for image in self.context.dataset.imgs.values()]
Expand All @@ -231,8 +328,11 @@ def _build_ui(self):

self.ui = ui.NrtkExplorerLayout(
server=self.server,
datasets_app=self._datasets_app,
images_app=self._images_app,
embeddings_app=self._embeddings_app,
filtering_app=self._filtering_app,
inference_app=self._inference_app,
transforms_app=self._transforms_app,
export_app=self._export_app,
**extra_args,
Expand Down
85 changes: 85 additions & 0 deletions src/nrtk_explorer/app/features/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import TypedDict


class EnabledFeatures(TypedDict):
datasets: bool
inference: bool
images: bool
embeddings: bool
transforms: bool
export: bool
filtering: bool


ALL_FEATURES: EnabledFeatures = {
"datasets": True,
"inference": True,
"images": True,
"embeddings": True,
"transforms": True,
"export": True,
"filtering": True,
}

NO_FEATURES: EnabledFeatures = {
"datasets": False,
"inference": False,
"images": False,
"embeddings": False,
"transforms": False,
"export": False,
"filtering": False,
}

DEFAULT_FEATURES = ALL_FEATURES

VIEWER_PRESET: EnabledFeatures = {
"datasets": True,
"inference": False,
"images": True,
"embeddings": False,
"transforms": False,
"export": False,
"filtering": True,
}

FEATURE_PRESETS = {
"all": ALL_FEATURES,
"none": NO_FEATURES,
"viewer": VIEWER_PRESET,
}


def validate_feature_name(feature: str) -> str:
known_features = set(DEFAULT_FEATURES.keys())

if feature not in known_features:
raise ValueError(f"Unknown feature '{feature}'. Known features are {known_features}")

return feature


def validate_preset_name(preset: str) -> str:
known_presets = set(FEATURE_PRESETS.keys())

if preset not in known_presets:
raise ValueError(f"Unknown preset '{preset}'. Known presets are {known_presets}")

return preset

def config_features_to_enabled_features(features: list[str] | None) -> EnabledFeatures:
if features is None:
return DEFAULT_FEATURES

enabled_features = NO_FEATURES.copy()

for feature in features:
enabled_features[feature] = True

return enabled_features

def config_preset_to_enabled_features(preset: str | None) -> EnabledFeatures:
if preset is None:
return DEFAULT_FEATURES

return FEATURE_PRESETS[preset]
Loading

0 comments on commit a579912

Please sign in to comment.