Skip to content

Commit

Permalink
Refactor the way analysis settings are set (#110)
Browse files Browse the repository at this point in the history
* Update settings specification for analysis wrappers

- Only allow specification through keyword arguments, use `with_settings` classmethod to use a settings object
- Use __init_subclass__ to automatically determine the settings object type from the class signature
- Use __init_subclass__ to specify class-level settings as arguments instead of attributes

* Add test to check if analysis wrapper's __init__ only accepts keyword-only parameters

* Allow batch updating settings with update_settings

---------

Co-authored-by: lvanvught <[email protected]>
  • Loading branch information
crnh and LucVV authored Feb 7, 2025
1 parent cd3c92b commit 858a2ec
Show file tree
Hide file tree
Showing 27 changed files with 345 additions and 179 deletions.
17 changes: 10 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@ build-backend = "hatchling.build"
[project]
name = "zospy"
authors = [
{name = "Luc van Vught"},
{name = "Jan-Willem Beenakker"},
{name = "Corné Haasjes"}
{ name = "Luc van Vught" },
{ name = "Jan-Willem Beenakker" },
{ name = "Corné Haasjes" }
]
maintainers = [
{name = "MReye research group", email = "[email protected]"}
{ name = "MReye research group", email = "[email protected]" }
]

description = "A Python package used to communicate with Zemax OpticStudio through the API"
readme = "README.md"
license = {file = "LICENSE.txt"}
license = { file = "LICENSE.txt" }
keywords = ["Zemax", "OpticStudio", "API", "ZOSAPI"]
classifiers = [
"Development Status :: 5 - Production/Stable",
Expand All @@ -30,7 +30,8 @@ dependencies = [
"pydantic >= 2.4.0",
"numpy",
"semver >= 3.0.0,<4",
"eval_type_backport", # TODO: Remove when dropping support for Python 3.9
"eval_type_backport; python_version <= '3.9'", # TODO: Remove when dropping support for Python 3.9
"typing_extensions; python_version <= '3.10'"
]
dynamic = ["version"]

Expand All @@ -57,6 +58,9 @@ path = "zospy/__init__.py"
[tool.hatch.envs.default]
python = "3.12"
installer = "uv"
path = ".venv"
dependencies = ["pytest"]


[tool.hatch.envs.default.scripts]
test-extension = "hatch test --extension {args}"
Expand Down Expand Up @@ -118,7 +122,6 @@ extend-include = [
exclude = [
"zospy/api/_ZOSAPI",
"zospy/api/_ZOSAPI_constants",

# TODO: Change this when movind the old analyses to zospy.analyses.old
"zospy/analyses/base.py",
"zospy/analyses/extendedscene.py",
Expand Down
10 changes: 5 additions & 5 deletions tests/analyses/new/parsers/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ def test_validated_ndarray_invalid_list(self):
validated_ndarray.validate_python([[1, 2, 3], [4, 5]])


class TestZOSAPIConstant:
@staticmethod
def _get_instances():
return [obj for obj in gc.get_objects() if isinstance(obj, ZOSAPIConstantAnnotation)]
def _get_zosapi_constant_instances():
return [obj for obj in gc.get_objects() if isinstance(obj, ZOSAPIConstantAnnotation)]


class TestZOSAPIConstant:
@staticmethod
def _hasattr(obj, attr):
for name in attr.split("."):
Expand All @@ -97,6 +97,6 @@ def _hasattr(obj, attr):

return True

@pytest.mark.parametrize("annotation", _get_instances())
@pytest.mark.parametrize("annotation", _get_zosapi_constant_instances())
def test_constant_exists(self, zos, annotation): # noqa: ARG002
assert self._hasattr(constants, annotation.enum)
93 changes: 84 additions & 9 deletions tests/analyses/new/test_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import inspect
import json
from dataclasses import fields
Expand All @@ -7,6 +9,7 @@
import numpy as np
import pytest
from pandas import DataFrame
from pydantic import Field
from pydantic.dataclasses import dataclass
from pydantic.fields import FieldInfo

Expand All @@ -15,14 +18,24 @@
AnalysisData,
AnalysisMetadata,
AnalysisResult,
AnalysisSettings,
BaseAnalysisWrapper,
_validated_setter,
)
from zospy.analyses.new.decorators import analysis_settings
from zospy.analyses.new.parsers.types import ValidatedDataFrame
from zospy.analyses.new.reports.surface_data import SurfaceDataSettings
from zospy.analyses.new.systemviewers.base import SystemViewerWrapper


def all_subclasses(cls):
return set(cls.__subclasses__()).union([s for c in cls.__subclasses__() for s in all_subclasses(c)])


analysis_wrapper_classes = all_subclasses(BaseAnalysisWrapper)
analysis_wrapper_classes.remove(SystemViewerWrapper)


class TestValidatedSetter:
class MockSettings:
int_setting: int = 1
Expand Down Expand Up @@ -53,20 +66,16 @@ def test_set_non_existing(self):
settings.non_existing = 2


analysis_wrapper_classes = BaseAnalysisWrapper.__subclasses__()
analysis_wrapper_classes.remove(SystemViewerWrapper)


@dataclass
class MockAnalysisData:
int_data: int = 1
string_data: str = "a"


@dataclass
@analysis_settings
class MockAnalysisSettings:
int_setting: int = 1
string_setting: str = "a"
int_setting: int = Field(default=1, description="An integer setting")
string_setting: str = Field(default="a", description="A string setting")


class MockAnalysis(BaseAnalysisWrapper[MockAnalysisData, MockAnalysisSettings]):
Expand All @@ -75,8 +84,14 @@ class MockAnalysis(BaseAnalysisWrapper[MockAnalysisData, MockAnalysisSettings]):
_needs_config_file = False
_needs_text_output_file = False

def __init__(self, int_setting: int = 1, string_setting: str = "a", *, block_remove_temp_files: bool = False):
super().__init__(MockAnalysisSettings(), locals())
def __init__(
self,
*,
int_setting: int = 1,
string_setting: str = "a",
block_remove_temp_files: bool = False,
):
super().__init__(settings_kws=locals())

self.block_remove_temp_files = block_remove_temp_files

Expand Down Expand Up @@ -112,11 +127,21 @@ def get_settings_defaults(settings_class):

return result

def test_get_settings_type(self):
assert MockAnalysis._settings_type == MockAnalysisSettings # noqa: SLF001

def test_settings_type_is_specified(self):
assert MockAnalysis._settings_type is not AnalysisSettings # noqa: SLF001

@pytest.mark.parametrize("cls", analysis_wrapper_classes)
def test_analyses_correct_analysis_name(self, cls):
assert cls.TYPE is not None
assert hasattr(constants.Analysis.AnalysisIDM, cls.TYPE)

@pytest.mark.parametrize("cls", analysis_wrapper_classes)
def test_init_all_keyword_only_parameters(self, cls):
all(p.kind.name == "KEYWORD_ONLY" for _, p in inspect.signature(cls).parameters.items())

@pytest.mark.parametrize("cls", analysis_wrapper_classes)
def test_init_contains_all_settings(self, cls):
if cls().settings is None:
Expand All @@ -139,6 +164,56 @@ def test_analyses_default_values(self, cls):
assert field_name in init_signature.parameters
assert init_signature.parameters[field_name].default == default_value

def test_change_settings_from_parameters(self):
analysis = MockAnalysis(int_setting=2, string_setting="b")

assert analysis.settings.int_setting == 2
assert analysis.settings.string_setting == "b"

def test_change_settings_from_object(self):
settings = MockAnalysisSettings(int_setting=2, string_setting="b")
analysis = MockAnalysis.with_settings(settings)

assert analysis.settings.int_setting == 2
assert analysis.settings.string_setting == "b"

def test_settings_object_is_copied(self):
settings = MockAnalysisSettings(int_setting=2, string_setting="b")
analysis = MockAnalysis.with_settings(settings)

assert analysis.settings is not settings
assert analysis.settings == settings

def test_update_settings_object(self):
analysis = MockAnalysis(int_setting=1, string_setting="a")

analysis.update_settings(settings=MockAnalysisSettings(int_setting=2, string_setting="b"))

assert analysis.settings.int_setting == 2
assert analysis.settings.string_setting == "b"

def test_update_settings_dictionary(self):
analysis = MockAnalysis(int_setting=1, string_setting="a")

analysis.update_settings(settings_kws={"int_setting": 2, "string_setting": "b"})

assert analysis.settings.int_setting == 2
assert analysis.settings.string_setting == "b"

def test_update_settings_object_and_dictionary(self):
analysis = MockAnalysis(int_setting=1, string_setting="a")

analysis.update_settings(
settings=MockAnalysisSettings(int_setting=2, string_setting="a"), settings_kws={"string_setting": "b"}
)

assert analysis.settings.int_setting == 2
assert analysis.settings.string_setting == "b"

def test_update_settings_no_dataclass_raises_type_error(self):
with pytest.raises(TypeError, match="settings should be a dataclass"):
MockAnalysis().update_settings(settings=123)

@pytest.mark.parametrize(
"temp_file_type,filename",
[
Expand Down
2 changes: 1 addition & 1 deletion tests/analyses/new/test_systemviewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class MockSystemViewerSettings:

class MockSystemViewer(SystemViewerWrapper[MockSystemViewerSettings]):
def __init__(self, *, number: int = 5, settings: TestBase.MockSystemViewerSettings | None = None):
super().__init__(settings or TestBase.MockSystemViewerSettings(), locals())
super().__init__(locals())

def _create_analysis(self, *, settings_first=True): # noqa: ARG002
self._analysis = SimpleNamespace(
Expand Down
114 changes: 104 additions & 10 deletions zospy/analyses/new/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@

from __future__ import annotations

import dataclasses
import os
import weakref
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields, is_dataclass
from dataclasses import dataclass, is_dataclass
from datetime import datetime # noqa: TCH003 Pydantic needs datetime to be present at runtime
from enum import Enum
from importlib import import_module
from pathlib import Path
from tempfile import mkstemp
from typing import TYPE_CHECKING, Generic, Literal, TypedDict, TypeVar, cast
from typing import TYPE_CHECKING, Generic, Literal, TypedDict, TypeVar, cast, get_args

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -544,8 +545,24 @@ class BaseAnalysisWrapper(ABC, Generic[AnalysisData, AnalysisSettings]):
_needs_config_file: bool = False
_needs_text_output_file: bool = False

def __init__(self, settings: AnalysisSettings, settings_arguments: dict[str, any]):
self._init_settings(settings, settings_arguments)
def __init__(self, *, settings_kws: dict[str, any] | None = None):
"""Create a new analysis wrapper.
Settings can be changed by passing the settings as keyword arguments. Use the `with_settings` method to specify
the settings using a settings object.
Parameters
----------
settings_kws : dict[str, any]
Arguments to set the settings of the analysis.
Raises
------
ValueError
If `settings` is not a dataclass.
"""
self._settings = self._default_settings()
self.update_settings(settings_kws=settings_kws)

self._config_file = None
self._text_output_file = None
Expand All @@ -555,18 +572,95 @@ def __init__(self, settings: AnalysisSettings, settings_arguments: dict[str, any
self._remove_config_file = False
self._remove_text_output_file = False

def _init_settings(self, settings: AnalysisSettings, parameters: dict[str, any]):
self._settings = settings
def __init_subclass__(
cls,
*,
analysis_type: str | None = None,
mode: Literal["Sequential", "Nonsequential"] | None = None,
needs_config_file: bool = False,
needs_text_output_file: bool = False,
**kwargs,
):
"""Determine the settings type and class-level configuration of the analysis."""
cls.TYPE = analysis_type
cls.MODE = mode
cls._needs_config_file = needs_config_file
cls._needs_text_output_file = needs_text_output_file

if not hasattr(cls, "_settings_type"):
if hasattr(cls, "__orig_bases__"):
base = cls.__orig_bases__[0]
cls._settings_type: type[AnalysisSettings] = get_args(base)[1]
else:
cls._settings_type = type(None) # TODO: change to NoneType when dropping support for Python 3.9

super().__init_subclass__(**kwargs)

def update_settings(
self, *, settings: AnalysisSettings | None = None, settings_kws: dict[str, any] | None = None
) -> None:
"""Update the settings of the analysis using a settings object or keyword arguments.
Settings can be specified as an object and as keyword arguments. If both are specified, the keyword arguments
take precedence. If no settings are specified, the default settings are used. Furthermore, instead of using
a reference to the settings object, a new settings object is created with the specified parameters. This is done
to avoid modifying the original settings object.
Parameters
----------
settings : AnalysisSettings
Analysis settings object.
settings_kws
Dictionary with the settings parameters.
Raises
------
ValueError
If `settings` is not a dataclass.
"""
# Use the existing settings if no settings are specified
settings = settings or self.settings

if settings is None:
# Analysis does not have settings
return

if not is_dataclass(settings):
raise ValueError("settings should be a dataclass.")
raise TypeError("settings should be a dataclass.")

# Create a new settings object with the specified parameters. If no parameters are specified, this creates a
# copy of the settings object. This is done to avoid modifying the original settings object.
self._settings = dataclasses.replace(settings, **(settings_kws or {}))

@classmethod
def _default_settings(cls) -> AnalysisSettings:
"""Get the default settings of the analysis.
Returns
-------
AnalysisSettings
The default settings.
"""
return cls._settings_type() if cls._settings_type is not None else None

@classmethod
def with_settings(cls, settings: AnalysisSettings):
"""Create a new analysis with the specified settings.
Parameters
----------
settings : AnalysisSettings
Settings of the analysis.
Returns
-------
BaseAnalysisWrapper
The analysis wrapper.
"""
instance = cls()
instance.update_settings(settings=settings)

for field in fields(settings):
if field.name in parameters:
setattr(self.settings, field.name, parameters[field.name])
return instance

@property
def settings(self) -> AnalysisSettings:
Expand Down
Loading

0 comments on commit 858a2ec

Please sign in to comment.