From f610cf9b74b1484ee4d03f21c61867c615b7b65a Mon Sep 17 00:00:00 2001 From: grayson-helmholz <107720976+grayson-helmholz@users.noreply.github.com> Date: Fri, 18 Oct 2024 17:03:05 +0200 Subject: [PATCH] MAINT: have `qn_domains`-keys in Node/EdgeSettings be typed (#292) * now compatible with python3.12 --- docs/conf.py | 23 +++++++++++++++++++++++ src/qrules/quantum_numbers.py | 32 ++++++++++++++++++++++++++++++++ src/qrules/solving.py | 6 ++++-- 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 890d61fd..13ac2193 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -2,6 +2,7 @@ import os import sys +import typing from sphinx_api_relink.helpers import ( get_branch_name, @@ -11,9 +12,20 @@ set_intersphinx_version_remapping, ) +from qrules.quantum_numbers import EdgeQuantumNumbers, NodeQuantumNumbers + sys.path.insert(0, os.path.abspath(".")) from _extend_docstrings import extend_docstrings # noqa: PLC2701 + +def pick_newtype_attrs(some_type: type) -> list: + return [ + attr + for attr in dir(some_type) + if type(getattr(some_type, attr)) is typing.NewType + ] + + extend_docstrings() set_intersphinx_version_remapping({ "ipython": { @@ -261,6 +273,16 @@ nb_execution_show_tb = True nb_execution_timeout = -1 nb_output_stderr = "remove" + + +nitpick_temp_names = [ + *pick_newtype_attrs(EdgeQuantumNumbers), + *pick_newtype_attrs(NodeQuantumNumbers), +] +nitpick_temp_patterns = [ + (r"py:(class|obj)", r"qrules\.quantum_numbers\." + name) + for name in nitpick_temp_names +] nitpick_ignore_regex = [ (r"py:(class|obj)", "json.encoder.JSONEncoder"), (r"py:(class|obj)", r"qrules\.topology\.EdgeType"), @@ -269,6 +291,7 @@ (r"py:(class|obj)", r"qrules\.topology\.NewNodeType"), (r"py:(class|obj)", r"qrules\.topology\.NodeType"), (r"py:(class|obj)", r"qrules\.topology\.VT"), + *nitpick_temp_patterns, ] nitpicky = True primary_domain = "py" diff --git a/src/qrules/quantum_numbers.py b/src/qrules/quantum_numbers.py index ba0d6129..a64a65af 100644 --- a/src/qrules/quantum_numbers.py +++ b/src/qrules/quantum_numbers.py @@ -127,6 +127,29 @@ class EdgeQuantumNumbers: EdgeQuantumNumbers.g_parity, ] +# for accessing the keys of the dicts in EdgeSettings +EdgeQuantumNumberTypes = Union[ + type[EdgeQuantumNumbers.pid], + type[EdgeQuantumNumbers.mass], + type[EdgeQuantumNumbers.width], + type[EdgeQuantumNumbers.spin_magnitude], + type[EdgeQuantumNumbers.spin_projection], + type[EdgeQuantumNumbers.charge], + type[EdgeQuantumNumbers.isospin_magnitude], + type[EdgeQuantumNumbers.isospin_projection], + type[EdgeQuantumNumbers.strangeness], + type[EdgeQuantumNumbers.charmness], + type[EdgeQuantumNumbers.bottomness], + type[EdgeQuantumNumbers.topness], + type[EdgeQuantumNumbers.baryon_number], + type[EdgeQuantumNumbers.electron_lepton_number], + type[EdgeQuantumNumbers.muon_lepton_number], + type[EdgeQuantumNumbers.tau_lepton_number], + type[EdgeQuantumNumbers.parity], + type[EdgeQuantumNumbers.c_parity], + type[EdgeQuantumNumbers.g_parity], +] + @frozen(init=False) class NodeQuantumNumbers: @@ -155,6 +178,15 @@ class NodeQuantumNumbers: ] """Type hint for quantum numbers of interaction nodes.""" +# for accessing the keys of the dicts in NodeSettings +NodeQuantumNumberTypes = Union[ + type[NodeQuantumNumbers.l_magnitude], + type[NodeQuantumNumbers.l_projection], + type[NodeQuantumNumbers.s_magnitude], + type[NodeQuantumNumbers.s_projection], + type[NodeQuantumNumbers.parity_prefactor], +] + def _to_optional_float(optional_float: float | None) -> float | None: if optional_float is None: diff --git a/src/qrules/solving.py b/src/qrules/solving.py index 77bec661..de0e85f6 100644 --- a/src/qrules/solving.py +++ b/src/qrules/solving.py @@ -33,7 +33,9 @@ from qrules.quantum_numbers import ( EdgeQuantumNumber, EdgeQuantumNumbers, + EdgeQuantumNumberTypes, NodeQuantumNumber, + NodeQuantumNumberTypes, ) from qrules.topology import MutableTransition, Topology @@ -50,7 +52,7 @@ class EdgeSettings: conservation_rules: set[GraphElementRule] = field(factory=set) rule_priorities: dict[GraphElementRule, int] = field(factory=dict) - qn_domains: dict[Any, list] = field(factory=dict) + qn_domains: dict[EdgeQuantumNumberTypes, list] = field(factory=dict) @implement_pretty_repr @@ -70,7 +72,7 @@ class NodeSettings: conservation_rules: set[Rule] = field(factory=set) rule_priorities: dict[Rule, int] = field(factory=dict) - qn_domains: dict[Any, list] = field(factory=dict) + qn_domains: dict[NodeQuantumNumberTypes, list] = field(factory=dict) interaction_strength: float = 1.0