diff --git a/.gitignore b/.gitignore index e4e60a81..f7b7246a 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ version.py # Temporary files *.pyc +*.swp *.tmp *condaenv.* .coverage @@ -34,6 +35,8 @@ __pycache__/ htmlcov/ oryx-build-commands.txt prof/ +tags +TAGS # Virtual environments *venv/ diff --git a/docs/usage/visualize.ipynb b/docs/usage/visualize.ipynb index 495886e2..ad4456fe 100644 --- a/docs/usage/visualize.ipynb +++ b/docs/usage/visualize.ipynb @@ -507,15 +507,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Or any other properties of a {class}`.State`:" + "Or any other properties of a {class}`.State`, such as masses or $J^{PC}(I^G)$ numbers:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { + "jupyter": { + "source_hidden": true + }, "tags": [ - "hide-input" + "hide-input", + "scroll-output" ] }, "outputs": [], @@ -539,6 +543,73 @@ "graphviz.Source(dot)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "scroll-input", + "hide-input" + ] + }, + "outputs": [], + "source": [ + "from fractions import Fraction\n", + "\n", + "\n", + "def render_jpc_ig(state: State) -> str:\n", + " particle = state.particle\n", + " text = render_fraction(particle.spin)\n", + " if particle.parity is not None:\n", + " text += render_sign(particle.parity)\n", + " if particle.c_parity is not None:\n", + " text += render_sign(particle.c_parity)\n", + " if particle.isospin is not None and particle.g_parity is not None:\n", + " text += \"(\"\n", + " text += f\"{render_fraction(particle.isospin.magnitude)}\" # with opening brace\n", + " text += f\"{render_sign(particle.g_parity)}\" # with closing brace\n", + " text += \")\"\n", + " return text\n", + "\n", + "\n", + "def render_fraction(value: float) -> str:\n", + " fraction = Fraction(value)\n", + " if fraction.denominator == 1:\n", + " return str(fraction.numerator)\n", + " return f\"{fraction.numerator}/{fraction.denominator}\"\n", + "\n", + "\n", + "def render_sign(parity: int) -> str:\n", + " if parity == -1:\n", + " return \"⁻\"\n", + " if parity == +1:\n", + " return \"⁺\"\n", + " raise NotImplementedError\n", + "\n", + "\n", + "jpc_ig_transitions = sorted({\n", + " t.convert(\n", + " state_converter=render_jpc_ig,\n", + " interaction_converter=lambda _: None,\n", + " )\n", + " for t in reaction.transitions\n", + "})\n", + "dot = qrules.io.asdot(jpc_ig_transitions, collapse_graphs=True)\n", + "graphviz.Source(dot)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ":::{tip}\n", + "Note that collapsing the graphs also works for custom edge properties.\n", + ":::" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -601,7 +672,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.9.19" } }, "nbformat": 4, diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index 7b9d9491..f7ab24c1 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -429,10 +429,7 @@ def _(obj: tuple) -> str: if all(isinstance(o, (float, int)) for o in obj): spin = Spin(*obj) return _spin_to_str(spin) - if all(isinstance(o, Particle) for o in obj): - return "\n".join(map(as_string, obj)) - _LOGGER.warning(f"No DOT render implemented for tuple of size {len(obj)}") - return str(obj) + return "\n".join(map(as_string, obj)) def _get_particle_graphs( @@ -487,8 +484,8 @@ def __to_particle(state: Any) -> Particle: def _collapse_graphs( - graphs: Iterable[Transition[ParticleWithSpin, InteractionProperties]], -) -> list[FrozenTransition[tuple[Particle, ...], None]]: + graphs: Iterable[Transition[Any, Any]], +) -> list[FrozenTransition[tuple, None]]: transition_groups: dict[Topology, MutableTransition[set[Particle], None]] = { g.topology: MutableTransition( g.topology, @@ -501,11 +498,7 @@ def _collapse_graphs( topology = transition.topology group = transition_groups[topology] for state_id, state in transition.states.items(): - if isinstance(state, State): - particle = state.particle - else: - particle, _ = state - group.states[state_id].add(particle) + group.states[state_id].add(_strip_properties(state)) collected_graphs: list[FrozenTransition[tuple[Particle, ...], None]] = [] for topology in sorted(transition_groups): group = transition_groups[topology] @@ -513,10 +506,26 @@ def _collapse_graphs( FrozenTransition( topology, states={ - i: tuple(sorted(particles, key=lambda p: p.name)) + i: tuple(sorted(particles, key=_sorting_key)) for i, particles in group.states.items() }, interactions=group.interactions, ) ) return collected_graphs + + +def _strip_properties(state: Any) -> Any: + if isinstance(state, State): + return state.particle + if isinstance(state, str): + return state + return state + + +def _sorting_key(obj: Any) -> Any: + if isinstance(obj, State): + return obj.particle.name + if isinstance(obj, str): + return obj.lower() + return obj