Commit

Merge branch 'sort-fix' into embedding-fix
PaulHax committed Jan 20, 2025
2 parents dee044e + 3110baa commit f002578
Showing 11 changed files with 19 additions and 125 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -83,6 +83,7 @@ For more details on setting up a development environment see [DEVELOPMENT docs](
1. Merge `main` to `release` with a _merge commit_.
2. Run "Create Release" workflow with workflow from `release` branch.
3. Merge `release` to `main` with a _merge commit_.
+4. Check package versions in Conda Feedstock [meta.yaml file](https://github.com/conda-forge/nrtk-explorer-feedstock/blob/main/recipe/meta.yaml)

[1]: https://trame.readthedocs.io/en/latest/
[2]: https://www.kitware.com/
9 changes: 3 additions & 6 deletions pyproject.toml
@@ -30,7 +30,7 @@ dependencies = [
"accelerate",
"numpy",
"Pillow",
"pybsm>=0.6,<=0.9.0",
"pybsm==0.10.2",
"scikit-learn>=1.6.0",
"smqtk_image_io",
"tabulate",
@@ -42,15 +42,12 @@ dependencies = [
"transformers",
"datasets[vision]",
"umap-learn",
"nrtk[headless]>=0.12.0,<=0.16.0",
"nrtk[headless]==0.19.1",
"trame-annotations>=0.4.0",
"kwcoco",
]

[project.optional-dependencies]
-kwcoco= [
-"kwcoco",
-]

dev = [
"black",
"flake8",
2 changes: 1 addition & 1 deletion src/nrtk_explorer/app/embeddings.py
@@ -5,6 +5,7 @@
from nrtk_explorer.library import embeddings_extractor
from nrtk_explorer.library import dimension_reducers
from nrtk_explorer.library.dataset import get_dataset
+from nrtk_explorer.library.scoring import partition
from nrtk_explorer.app.applet import Applet

from nrtk_explorer.app.images.image_ids import (
@@ -230,7 +231,6 @@ def update_transformed_images(self, id_to_image):
**self._stashed_points_transformations,
**updated_points,
}
-
self.update_points_transformations_state()

# called by category filter
4 changes: 0 additions & 4 deletions src/nrtk_explorer/app/transforms.py
@@ -9,7 +9,6 @@
from trame_server import Server

import nrtk_explorer.library.transforms as trans
-import nrtk_explorer.library.nrtk_transforms as nrtk_trans
import nrtk_explorer.library.yaml_transforms as nrtk_yaml
from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor
from nrtk_explorer.library.app_config import process_config
@@ -204,9 +203,6 @@ def delete_meta_state(old_ids, new_ids):
"identity": trans.IdentityTransform,
}

-if nrtk_trans.nrtk_transforms_available():
-self._transform_classes["nrtk_pybsm"] = nrtk_trans.NrtkPybsmTransform
-
# Add transform from YAML definition
self._transform_classes.update(nrtk_yaml.generate_transforms())

4 changes: 2 additions & 2 deletions src/nrtk_explorer/library/annotations.py
@@ -1,5 +1,5 @@
from typing import TypedDict, List
-from .dataset import JsonDataset
+from .dataset import CocoDataset


def get_cat_id(dataset, annotation):
@@ -33,7 +33,7 @@ class Annotation(TypedDict, total=False):
bbox: List[float]


-def to_annotation(dataset: JsonDataset, prediction: Prediction) -> Annotation:
+def to_annotation(dataset: CocoDataset, prediction: Prediction) -> Annotation:
annotation: Annotation = {}

if "label" in prediction:
40 changes: 5 additions & 35 deletions src/nrtk_explorer/library/dataset.py
@@ -10,8 +10,8 @@
import os
from functools import lru_cache
from pathlib import Path
-import json
from PIL import Image
+import kwcoco
from datasets import (
load_dataset,
get_dataset_infos,
@@ -36,42 +36,12 @@ def build_cat_index(self):
self.name_to_cat = {cat["name"]: cat for cat in self.cats.values()}


-class JsonDataset(BaseDataset, CategoryIndex):
-"""JSON-based COCO datasets."""
-
-def __init__(self, path: str):
-with open(path) as f:
-self.data = json.load(f)
-self.fpath = path
-self.cats = {cat["id"]: cat for cat in self.data["categories"]}
-self.anns = {ann["id"]: ann for ann in self.data["annotations"]}
-self.imgs = {img["id"]: img for img in self.data["images"]}
-self.build_cat_index()
-
-def _get_image_fpath(self, selected_id: int):
-dataset_dir = Path(self.fpath).parent
-file_name = self.imgs[selected_id]["file_name"]
-return str(dataset_dir / file_name)
-
+class CocoDataset(kwcoco.CocoDataset, BaseDataset):
def get_image(self, id: int):
-image_fpath = self._get_image_fpath(id)
+image_fpath = self.get_image_fpath(id)
return Image.open(image_fpath)
-
-
-def make_coco_dataset(path: str):
-try:
-import kwcoco
-
-class CocoDataset(kwcoco.CocoDataset, BaseDataset):
-def get_image(self, id: int):
-image_fpath = self.get_image_fpath(id)
-return Image.open(image_fpath)
-
-return CocoDataset(path)
-except ImportError:
-return JsonDataset(path)


def is_coco_dataset(path: str):
if not os.path.exists(path) or os.path.isdir(path):
return False
@@ -102,7 +72,7 @@ def find_column_name(features, column_names):


class HuggingFaceDataset(BaseDataset, CategoryIndex):
"""Interface for Hugging Face datasets with a similar API to JsonDataset."""
"""Interface for Hugging Face datasets with a similar API to CocoDataset."""

def __init__(self, identifier: str):
self.imgs: dict[str, dict] = {}
@@ -244,7 +214,7 @@ def get_dataset(identifier: str):
absolute_path = str(Path(identifier).resolve())

if is_coco_dataset(absolute_path):
-return make_coco_dataset(absolute_path)
+return CocoDataset(absolute_path)

# Assume identifier is a Hugging Face Dataset
return HuggingFaceDataset(identifier)
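
For orientation, here is a minimal usage sketch of the consolidated `kwcoco`-backed dataset path after this change. It is illustrative only, not repository code: the annotation file path is hypothetical, and it assumes a local COCO-format JSON file.

```python
from nrtk_explorer.library.dataset import get_dataset

# Hypothetical local COCO annotation file. get_dataset() resolves the path and
# returns CocoDataset (a kwcoco.CocoDataset subclass) when the file looks like
# COCO JSON; otherwise it falls back to HuggingFaceDataset.
ds = get_dataset("data/annotations.json")

# kwcoco already provides the imgs/cats/anns index dicts, so the JsonDataset
# wrapper and the ImportError fallback in make_coco_dataset are no longer needed.
first_image_id = next(iter(ds.imgs))
image = ds.get_image(first_image_id)  # PIL image via kwcoco's get_image_fpath()
print(len(ds.cats), image.size)
```

Since `kwcoco` is now an unconditional dependency (see the `pyproject.toml` change above), the JSON fallback path could be dropped entirely.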
72 changes: 1 addition & 71 deletions src/nrtk_explorer/library/nrtk_transforms.py
@@ -1,33 +1,14 @@
-from typing import Any, Dict, Optional, TYPE_CHECKING

import numpy as np
-import logging
-from PIL import Image as ImageModule
-from PIL.Image import Image
-from nrtk_explorer.library.transforms import ImageTransform, ParameterDescription

-ENABLED_NRTK_TRANSFORMS = True

-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)

-try:
from pybsm.otf import dark_current_from_density
-from nrtk.impls.perturb_image.pybsm.perturber import PybsmPerturber, PybsmSensor, PybsmScenario
+from nrtk.impls.perturb_image.pybsm.perturber import PybsmSensor, PybsmScenario
-except ImportError:
-logger.info("Disabling NRTK transforms due to missing library/failing imports")
-ENABLED_NRTK_TRANSFORMS = False

-if TYPE_CHECKING:
-PybsmPerturberType = PybsmPerturber
-else:
-PybsmPerturberType = None

-PybsmPerturberArg = Optional[PybsmPerturberType]


-def nrtk_transforms_available():
-return ENABLED_NRTK_TRANSFORMS


# copied from https://github.com/Kitware/nrtk/blob/main/tests/impls/test_pybsm_utils.py
@@ -144,54 +125,3 @@ def create_sample_scenario():

def create_sample_sensor_and_scenario():
return dict(sensor=create_sample_sensor(), scenario=create_sample_scenario())
-
-
-class NrtkPybsmTransform(ImageTransform):
-def __init__(self, perturber: PybsmPerturberArg = None):
-if perturber is None:
-kwargs = create_sample_sensor_and_scenario()
-perturber = PybsmPerturber(**kwargs)
-
-self._perturber: PybsmPerturber = perturber
-
-def get_parameters(self) -> dict[str, Any]:
-return {
-"D": self._perturber.sensor.D,
-"f": self._perturber.sensor.f,
-}
-
-def set_parameters(self, params: Dict[str, Any]):
-self._perturber.sensor.D = params["D"]
-self._perturber.sensor.f = params["f"]
-
-@classmethod
-def get_parameters_description(cls) -> Dict[str, ParameterDescription]:
-aperture_description: ParameterDescription = {
-"type": "float",
-"label": "Effective Aperture (m)",
-"default": None,
-"description": None,
-"options": None,
-}
-
-focal_description: ParameterDescription = {
-"type": "float",
-"label": "Focal Length (m)",
-"default": None,
-"description": None,
-"options": None,
-}
-
-return {
-"D": aperture_description,
-"f": focal_description,
-}
-
-def execute(self, input: Image, *input_args: Any) -> Image:
-if len(input_args) == 0:
-input_args = ({"img_gsd": 0.15},)
-
-input_array = np.asarray(input)
-output_array = self._perturber.perturb(input_array, *input_args)
-
-return ImageModule.fromarray(output_array)
4 changes: 2 additions & 2 deletions src/nrtk_explorer/library/nrtk_transforms.yaml
@@ -115,10 +115,10 @@ nrtk_pybsm_detector_otf:
type: float
label: Focal length (m)

-nrtk_pybsm_2:
+nrtk_pybsm:
perturber: nrtk.impls.perturb_image.pybsm.perturber.PybsmPerturber
perturber_kwargs: nrtk_explorer.library.nrtk_transforms.create_sample_sensor_and_scenario
-exec_default_args: [{ img_gsd: 0.15 }]
+exec_default_args: [None, { img_gsd: 0.15 }]
description:
D:
_path: [sensor, D]
2 changes: 1 addition & 1 deletion src/nrtk_explorer/library/yaml_transforms.py
@@ -152,6 +152,6 @@ def execute(self, input, *input_args):
input_args = self.exec_args

input_array = np.asarray(input)
-output_array = self._perturber.perturb(input_array, *input_args)
+output_array, _ = self._perturber.perturb(input_array, *input_args)

return ImageModule.fromarray(output_array)
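
For context, a sketch of the call pattern the two changes above assume: the renamed `nrtk_pybsm` YAML entry supplies `exec_default_args: [None, { img_gsd: 0.15 }]`, and `execute()` now unpacks a tuple from `perturb()`. The synthetic image and the direct construction via `create_sample_sensor_and_scenario()` are illustrative assumptions, not code from this commit.

```python
import numpy as np
from PIL import Image as ImageModule
from nrtk.impls.perturb_image.pybsm.perturber import PybsmPerturber
from nrtk_explorer.library.nrtk_transforms import create_sample_sensor_and_scenario

# Build the perturber the way the nrtk_pybsm YAML entry describes it:
# the PybsmPerturber class plus sensor/scenario kwargs from the sample factory.
perturber = PybsmPerturber(**create_sample_sensor_and_scenario())

# Synthetic RGB image purely for illustration.
input_array = np.asarray(ImageModule.new("RGB", (256, 256), color=(128, 128, 128)))

# Mirrors exec_default_args: [None, { img_gsd: 0.15 }]. perturb() returns a
# tuple in this nrtk version, and only the perturbed image is kept.
output_array, _ = perturber.perturb(input_array, None, {"img_gsd": 0.15})
result = ImageModule.fromarray(output_array)
```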
4 changes: 2 additions & 2 deletions tests/test_dataset.py
@@ -1,4 +1,4 @@
-from nrtk_explorer.library.dataset import get_dataset, JsonDataset
+from nrtk_explorer.library.dataset import get_dataset, CocoDataset
import nrtk_explorer.test_data

from pathlib import Path
@@ -27,7 +27,7 @@ def test_get_dataset_empty():


def test_DefaultDataset(dataset_path):
-ds = JsonDataset(dataset_path)
+ds = CocoDataset(dataset_path)
assert len(ds.imgs) > 0
assert len(ds.cats) > 0
assert len(ds.anns) > 0
2 changes: 1 addition & 1 deletion tests/test_transforms.py
@@ -14,6 +14,6 @@ def test_gaussian_blur():

def test_pybsm():
transforms = generate_transforms()
pybsm = transforms["nrtk_pybsm_2"]()
pybsm = transforms["nrtk_pybsm"]()
pybsm.set_parameters({"D": 0.25, "f": 4.0})
pybsm.execute(get_image())
