Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: implement quantum problem set filter #278

Closed
wants to merge 13 commits into from
Closed
Prev Previous commit
Next Next commit
added pid&spin_projection -> non-zero results
  • Loading branch information
grayson-helmholz committed Aug 21, 2024
commit 6fbf83b7af916129c41515790fcc7f2a18fec90a
274 changes: 51 additions & 223 deletions tests/unit/test_solving.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterable, Union

import attrs
import pytest
@@ -22,11 +21,38 @@
from qrules.topology import MutableTransition

if TYPE_CHECKING:
from qrules.argument_handling import (
GraphEdgePropertyMap,
GraphNodePropertyMap,
Rule,
)
from qrules.argument_handling import Rule


EdgeQuantumNumberTypes = Union[
type[EdgeQuantumNumbers.pid],
type[EdgeQuantumNumbers.mass],
type[EdgeQuantumNumbers.width],
type[EdgeQuantumNumbers.spin_magnitude],
type[EdgeQuantumNumbers.spin_projection],
type[EdgeQuantumNumbers.charge],
type[EdgeQuantumNumbers.isospin_magnitude],
type[EdgeQuantumNumbers.isospin_projection],
type[EdgeQuantumNumbers.strangeness],
type[EdgeQuantumNumbers.charmness],
type[EdgeQuantumNumbers.bottomness],
type[EdgeQuantumNumbers.topness],
type[EdgeQuantumNumbers.baryon_number],
type[EdgeQuantumNumbers.electron_lepton_number],
type[EdgeQuantumNumbers.muon_lepton_number],
type[EdgeQuantumNumbers.tau_lepton_number],
type[EdgeQuantumNumbers.parity],
type[EdgeQuantumNumbers.c_parity],
type[EdgeQuantumNumbers.g_parity],
]

NodeQuantumNumberTypes = Union[
type[NodeQuantumNumbers.l_magnitude],
type[NodeQuantumNumbers.l_projection],
type[NodeQuantumNumbers.s_magnitude],
type[NodeQuantumNumbers.s_projection],
type[NodeQuantumNumbers.parity_prefactor],
]


def test_solve(
@@ -52,7 +78,9 @@ def test_solve_with_filtered_quantum_number_problem_set(
c_parity_conservation,
},
edge_domains=(
EdgeQuantumNumbers.pid, # had to be added for c_parity_conservation to work
EdgeQuantumNumbers.spin_magnitude,
EdgeQuantumNumbers.spin_projection, # had to be added for spin_magnitude_conservation to work
EdgeQuantumNumbers.parity,
EdgeQuantumNumbers.c_parity,
),
@@ -61,7 +89,9 @@ def test_solve_with_filtered_quantum_number_problem_set(
new_quantum_number_problem_set = filter_quantum_number_problem_set_properties(
new_quantum_number_problem_set,
edge_properties=(
EdgeQuantumNumbers.pid, # had to be added for c_parity_conservation to work
EdgeQuantumNumbers.spin_magnitude,
EdgeQuantumNumbers.spin_projection, # had to be added for spin_magnitude_conservation to work
EdgeQuantumNumbers.parity,
EdgeQuantumNumbers.c_parity,
),
@@ -75,154 +105,36 @@ def test_solve_with_filtered_quantum_number_problem_set(
assert len(result.solutions) != 0
redeboer marked this conversation as resolved.
Show resolved Hide resolved


def remove_quantum_number_problem_set_settings(
quantum_number_problem_set: QNProblemSet,
edge_rules_to_be_removed: set[GraphElementRule],
node_rules_to_be_removed: set[Rule],
edge_domains_to_be_removed: tuple[Any, ...],
node_domains_to_be_removed: tuple[Any, ...],
) -> QNProblemSet:
old_edge_settings = quantum_number_problem_set.solving_settings.states
old_node_settings = quantum_number_problem_set.solving_settings.interactions
new_edge_settings = {
edge_id: EdgeSettings(
conservation_rules=edge_setting.conservation_rules
- edge_rules_to_be_removed,
rule_priorities=edge_setting.rule_priorities,
qn_domains={
key: val
for key, val in edge_setting.qn_domains.items()
if key not in edge_domains_to_be_removed
},
)
for edge_id, edge_setting in old_edge_settings.items()
}
new_node_settings = {
node_id: NodeSettings(
conservation_rules=node_setting.conservation_rules
- node_rules_to_be_removed,
rule_priorities=node_setting.rule_priorities,
qn_domains={
key: val
for key, val in node_setting.qn_domains.items()
if key not in node_domains_to_be_removed
},
)
for node_id, node_setting in old_node_settings.items()
}
new_mutable_transition = MutableTransition(
topology=quantum_number_problem_set.solving_settings.topology,
states=new_edge_settings,
interactions=new_node_settings,
)
return attrs.evolve(
quantum_number_problem_set, solving_settings=new_mutable_transition
)


def remove_quantum_number_problem_set_properties(
quantum_number_problem_set: QNProblemSet,
edge_properties_to_be_removed: tuple[EdgeQuantumNumbers],
node_properties_to_be_removed: tuple[NodeQuantumNumbers],
) -> QNProblemSet:
old_edge_properties = quantum_number_problem_set.initial_facts.states
old_node_properties = quantum_number_problem_set.initial_facts.interactions
new_edge_properties = {
edge_id: {
edge_quantum_number: scalar
for edge_quantum_number, scalar in graph_edge_property_map.items()
if edge_quantum_number not in edge_properties_to_be_removed
}
for edge_id, graph_edge_property_map in old_edge_properties.items()
}
new_node_properties = {
node_id: {
node_quantum_number: scalar
for node_quantum_number, scalar in graph_node_property_map.items()
if node_quantum_number not in node_properties_to_be_removed
}
for node_id, graph_node_property_map in old_node_properties.items()
}
new_mutable_transition = MutableTransition(
topology=quantum_number_problem_set.initial_facts.topology,
states=new_edge_properties,
interactions=new_node_properties,
)
return attrs.evolve(
quantum_number_problem_set, initial_facts=new_mutable_transition
)


def test_inner_dicts_unchanged(
quantum_number_problem_set: QNProblemSet,
) -> None:
old_inner_graph_edge_property_map = copy.deepcopy(
quantum_number_problem_set.initial_facts.states
)
old_inner_graph_node_property_map = copy.deepcopy(
quantum_number_problem_set.initial_facts.interactions
)
graph_edge_property_map = {
EdgeQuantumNumbers.spin_magnitude: 1,
EdgeQuantumNumbers.parity: -1,
EdgeQuantumNumbers.c_parity: 1,
}
graph_node_property_map = {
NodeQuantumNumbers.s_magnitude: 1,
NodeQuantumNumbers.l_magnitude: 0,
}
quantum_number_problem_set_with_new_properties(
quantum_number_problem_set, graph_edge_property_map, graph_node_property_map
)
assert (
old_inner_graph_edge_property_map
== quantum_number_problem_set.initial_facts.states
)
assert (
old_inner_graph_node_property_map
== quantum_number_problem_set.initial_facts.interactions
)


def filter_quantum_number_problem_set_settings(
quantum_number_problem_set: QNProblemSet,
edge_rules: set[GraphElementRule],
node_rules: set[Rule],
edge_domains: tuple[Any, ...],
node_domains: tuple[Any, ...],
keep_domains: bool = True,
edge_domains: Iterable[Any],
node_domains: Iterable[Any],
) -> QNProblemSet:
old_edge_settings = quantum_number_problem_set.solving_settings.states
old_node_settings = quantum_number_problem_set.solving_settings.interactions
new_edge_settings = {
edge_id: EdgeSettings(
conservation_rules=edge_rules,
rule_priorities=edge_setting.rule_priorities,
qn_domains=(
edge_setting.qn_domains
if keep_domains
else {
key: val
for key, val in edge_setting.qn_domains.items()
if key in edge_domains
}
),
qn_domains=({
key: val
for key, val in edge_setting.qn_domains.items()
if key in set(edge_domains)
}),
)
for edge_id, edge_setting in old_edge_settings.items()
}
new_node_settings = {
node_id: NodeSettings(
conservation_rules=node_rules,
rule_priorities=node_setting.rule_priorities,
qn_domains=(
node_setting.qn_domains
if keep_domains
else {
key: val
for key, val in node_setting.qn_domains.items()
if key in node_domains
}
),
qn_domains=({
key: val
for key, val in node_setting.qn_domains.items()
if key in set(node_domains)
}),
)
for node_id, node_setting in old_node_settings.items()
}
@@ -238,8 +150,8 @@ def filter_quantum_number_problem_set_settings(

def filter_quantum_number_problem_set_properties(
quantum_number_problem_set: QNProblemSet,
edge_properties: tuple[EdgeQuantumNumbers],
node_properties: tuple[NodeQuantumNumbers],
edge_properties: Iterable[EdgeQuantumNumberTypes],
node_properties: Iterable[NodeQuantumNumberTypes],
) -> QNProblemSet:
old_edge_properties = quantum_number_problem_set.initial_facts.states
old_node_properties = quantum_number_problem_set.initial_facts.interactions
@@ -269,90 +181,6 @@ def filter_quantum_number_problem_set_properties(
)


def quantum_number_problem_set_with_new_settings(
quantum_number_problem_set: QNProblemSet,
edge_rules: set[GraphElementRule],
node_rules: set[Rule],
edge_domains: dict[Any, list],
node_domains: dict[Any, list],
) -> QNProblemSet:
def qnp_with_new_rules(
quantum_number_problem_set: QNProblemSet,
edge_rules: set[GraphElementRule],
node_rules: set[Rule],
) -> QNProblemSet:
old_settings = quantum_number_problem_set.solving_settings
new_settings = attrs.evolve(
quantum_number_problem_set.solving_settings,
states={
edge_id: attrs.evolve(
setting, conservation_rules=setting.conservation_rules & edge_rules
)
for edge_id, setting in old_settings.states.items()
},
interactions={
node_id: attrs.evolve(
setting, conservation_rules=setting.conservation_rules & node_rules
)
for node_id, setting in old_settings.interactions.items()
},
)
return attrs.evolve(quantum_number_problem_set, solving_settings=new_settings)

def qnp_with_new_domains(
quantum_number_problem_set: QNProblemSet,
edge_domains: dict[Any, list],
node_domains: dict[Any, list],
) -> QNProblemSet:
old_settings = quantum_number_problem_set.solving_settings
new_settings = attrs.evolve(
old_settings,
states={
edge_id: attrs.evolve(setting, qn_domains=edge_domains)
for edge_id, setting in old_settings.states.items()
},
interactions={
node_id: attrs.evolve(setting, qn_domains=node_domains)
for node_id, setting in old_settings.interactions.items()
},
)
return attrs.evolve(quantum_number_problem_set, solving_settings=new_settings)

return qnp_with_new_rules(
qnp_with_new_domains(quantum_number_problem_set, edge_domains, node_domains),
edge_rules,
node_rules,
)


def quantum_number_problem_set_with_new_properties(
quantum_number_problem_set: QNProblemSet,
graph_edge_property_map: GraphEdgePropertyMap,
graph_node_property_map: GraphNodePropertyMap,
) -> QNProblemSet:
old_facts = quantum_number_problem_set.initial_facts
new_facts = attrs.evolve(
old_facts,
states={
node_id: {
key: val
for key, val in prop_map.items()
if key in graph_edge_property_map
}
for node_id, prop_map in old_facts.states.items()
},
interactions={
node_id: {
key: val
for key, val in prop_map.items()
if key in graph_node_property_map
}
for node_id, prop_map in old_facts.interactions.items()
},
)
return attrs.evolve(quantum_number_problem_set, initial_facts=new_facts)


@pytest.fixture(scope="session")
def all_particles():
return [
Loading