Skip to content

Commit

Permalink
FIX: show only selected rules in DOT rendering (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Jul 8, 2023
1 parent fb1485a commit d905f8f
Showing 1 changed file with 38 additions and 11 deletions.
49 changes: 38 additions & 11 deletions src/qrules/io/_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,23 @@
import functools
import re
from collections import abc
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
from inspect import isfunction
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)

import attrs

from qrules.argument_handling import Rule
from qrules.combinatorics import InitialFacts
from qrules.particle import Particle, ParticleCollection, ParticleWithSpin
from qrules.quantum_numbers import InteractionProperties, _to_fraction
Expand Down Expand Up @@ -419,31 +432,45 @@ def __node_label(node_prop: Union[InteractionProperties, NodeSettings]) -> str:
def __render_settings(settings: Union[EdgeSettings, NodeSettings]) -> str:
output = ""
if settings.rule_priorities:
output += "RULE PRIORITIES\n"
rule_names = (
f"{item[0].__name__} - {item[1]}" # type: ignore[union-attr]
for item in settings.rule_priorities.items()
output += "RULES\n"
rule_descriptions = (
f"{__render_rule(rule)} - {__get_priority(rule, settings.rule_priorities)}"
for rule in settings.conservation_rules
)
sorted_names = sorted(rule_names, key=__extract_priority, reverse=True)
sorted_names = sorted(rule_descriptions, key=__extract_priority, reverse=True)
output += "\n".join(sorted_names)
if settings.qn_domains:
if output:
output += "\n"
domains = sorted(
f"{item[0].__name__}{item[1]}" for item in settings.qn_domains.items()
f"{qn.__name__}{domain}" for qn, domain in settings.qn_domains.items()
)
output += "DOMAINS\n"
output += "\n".join(domains)
return output


def __extract_priority(description: str) -> int:
matches = re.match(r".* \- ([0-9]+)$", description)
def __get_priority(rule: Any, rule_priorities: Dict[Any, int]) -> Union[int, str]:
rule_type = __get_type(rule)
return rule_priorities.get(rule_type, "NA")


def __render_rule(rule: Rule) -> str:
return __get_type(rule).__name__


def __get_type(rule: Rule) -> Type[Rule]:
if isfunction(rule):
return rule # type: ignore[return-value]
return type(rule)


def __extract_priority(description: str) -> str:
matches = re.match(r".* \- ([0-9]+|NA)$", description)
if matches is None:
msg = f"{description} does not contain a priority number"
raise ValueError(msg)
priority = matches[1]
return int(priority)
return matches[1]


def _get_particle_graphs(
Expand Down

0 comments on commit d905f8f

Please sign in to comment.