From d905f8f2910f19ec944f0eee4d2f6d8e175d7eff Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Sat, 8 Jul 2023 22:23:29 +0200 Subject: [PATCH] FIX: show only selected rules in DOT rendering (#225) --- src/qrules/io/_dot.py | 49 +++++++++++++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index 7604aae6..05bc8d52 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -6,10 +6,23 @@ import functools import re from collections import abc -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union +from inspect import isfunction +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Type, + Union, +) import attrs +from qrules.argument_handling import Rule from qrules.combinatorics import InitialFacts from qrules.particle import Particle, ParticleCollection, ParticleWithSpin from qrules.quantum_numbers import InteractionProperties, _to_fraction @@ -419,31 +432,45 @@ def __node_label(node_prop: Union[InteractionProperties, NodeSettings]) -> str: def __render_settings(settings: Union[EdgeSettings, NodeSettings]) -> str: output = "" if settings.rule_priorities: - output += "RULE PRIORITIES\n" - rule_names = ( - f"{item[0].__name__} - {item[1]}" # type: ignore[union-attr] - for item in settings.rule_priorities.items() + output += "RULES\n" + rule_descriptions = ( + f"{__render_rule(rule)} - {__get_priority(rule, settings.rule_priorities)}" + for rule in settings.conservation_rules ) - sorted_names = sorted(rule_names, key=__extract_priority, reverse=True) + sorted_names = sorted(rule_descriptions, key=__extract_priority, reverse=True) output += "\n".join(sorted_names) if settings.qn_domains: if output: output += "\n" domains = sorted( - f"{item[0].__name__} ∊ {item[1]}" for item in settings.qn_domains.items() + f"{qn.__name__} ∊ {domain}" for qn, domain in settings.qn_domains.items() ) output += "DOMAINS\n" output += "\n".join(domains) return output -def __extract_priority(description: str) -> int: - matches = re.match(r".* \- ([0-9]+)$", description) +def __get_priority(rule: Any, rule_priorities: Dict[Any, int]) -> Union[int, str]: + rule_type = __get_type(rule) + return rule_priorities.get(rule_type, "NA") + + +def __render_rule(rule: Rule) -> str: + return __get_type(rule).__name__ + + +def __get_type(rule: Rule) -> Type[Rule]: + if isfunction(rule): + return rule # type: ignore[return-value] + return type(rule) + + +def __extract_priority(description: str) -> str: + matches = re.match(r".* \- ([0-9]+|NA)$", description) if matches is None: msg = f"{description} does not contain a priority number" raise ValueError(msg) - priority = matches[1] - return int(priority) + return matches[1] def _get_particle_graphs(