From 9640f1d1f52da66fe8d7ec7eaf575d420719b5f8 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Sun, 3 Mar 2024 22:46:29 +0100 Subject: [PATCH] ENH: specify allowed formalisms with `Literal` (#253) --- docs/conf.py | 3 +- pyproject.toml | 2 +- src/qrules/__init__.py | 3 +- src/qrules/settings.py | 3 +- src/qrules/transition.py | 37 ++++++++++++++------- tests/channels/test_psi2s_to_eta_k_kstar.py | 3 +- tests/unit/conftest.py | 6 +++- tests/unit/io/test_dot.py | 4 +-- tests/unit/test_settings.py | 3 +- 9 files changed, 43 insertions(+), 21 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index dab953bf..199db8df 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -64,8 +64,8 @@ def create_constraints_inventory() -> None: api_target_substitutions: dict[str, str | tuple[str, str]] = { "EdgeType": "typing.TypeVar", "GraphEdgePropertyMap": ("obj", "qrules.argument_handling.GraphEdgePropertyMap"), - "GraphNodePropertyMap": ("obj", "qrules.argument_handling.GraphNodePropertyMap"), "GraphElementProperties": ("obj", "qrules.solving.GraphElementProperties"), + "GraphNodePropertyMap": ("obj", "qrules.argument_handling.GraphNodePropertyMap"), "GraphSettings": ("obj", "qrules.solving.GraphSettings"), "InitialFacts": ("obj", "qrules.combinatorics.InitialFacts"), "NewEdgeType": "typing.TypeVar", @@ -75,6 +75,7 @@ def create_constraints_inventory() -> None: "ParticleWithSpin": ("obj", "qrules.particle.ParticleWithSpin"), "qrules.topology.EdgeType": "typing.TypeVar", "qrules.topology.NodeType": "typing.TypeVar", + "SpinFormalism": ("obj", "qrules.transition.SpinFormalism"), "StateDefinition": ("obj", "qrules.combinatorics.StateDefinition"), "TypeAlias": "typing.TypeAlias", "typing_extensions.TypeAlias": "typing.TypeAlias", diff --git a/pyproject.toml b/pyproject.toml index 54c9e380..c85dff69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ "particle", "python-constraint", "tqdm >=4.24.0", # autonotebook - 'typing-extensions; python_version <"3.10.0"', # typing.TypeAlias + 'typing-extensions; python_version <"3.10.0"', # Literal, Protocol, TypeAlias ] description = "Rule-based particle reaction problem solver on a quantum number level" dynamic = ["version"] diff --git a/src/qrules/__init__.py b/src/qrules/__init__.py index 1cc2941f..9c291320 100644 --- a/src/qrules/__init__.py +++ b/src/qrules/__init__.py @@ -59,6 +59,7 @@ EdgeSettings, ProblemSet, ReactionInfo, + SpinFormalism, StateTransitionManager, ) @@ -264,7 +265,7 @@ def generate_transitions( # noqa: PLR0917 final_state: Sequence[StateDefinition], allowed_intermediate_particles: list[str] | None = None, allowed_interaction_types: str | Iterable[str] | None = None, - formalism: str = "canonical-helicity", + formalism: SpinFormalism = "canonical-helicity", particle_db: ParticleCollection | None = None, mass_conservation_factor: float | None = 3.0, max_angular_momentum: int = 2, diff --git a/src/qrules/settings.py b/src/qrules/settings.py index c83f28a5..1bf920d1 100644 --- a/src/qrules/settings.py +++ b/src/qrules/settings.py @@ -50,6 +50,7 @@ if TYPE_CHECKING: from qrules.particle import Particle, ParticleCollection + from qrules.transition import SpinFormalism __QRULES_PATH = dirname(realpath(__file__)) ADDITIONAL_PARTICLES_DEFINITIONS_PATH: str = join( @@ -118,7 +119,7 @@ def from_str(description: str) -> InteractionType: def create_interaction_settings( # noqa: PLR0917 - formalism: str, + formalism: SpinFormalism, particle_db: ParticleCollection, nbody_topology: bool = False, mass_conservation_factor: float | None = 3.0, diff --git a/src/qrules/transition.py b/src/qrules/transition.py index aa5491f5..9eca9f67 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -14,7 +14,7 @@ import attrs from attrs import define, field, frozen -from attrs.validators import instance_of +from attrs.validators import in_, instance_of from tqdm.auto import tqdm from qrules._implementers import implement_pretty_repr @@ -74,6 +74,11 @@ create_n_body_topology, ) +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + if sys.version_info >= (3, 10): from typing import TypeAlias else: @@ -83,6 +88,19 @@ _LOGGER = logging.getLogger(__name__) +SpinFormalism = Literal[ + "helicity", + "canonical-helicity", + "canonical", +] +"""Name for the spin formalism to be used. + +The options :code:`"helicity"`, :code:`"canonical-helicity"`, and :code:`"canonical"` +are all used for the helicity formalism, but :code:`"canonical-helicity"` and +:code:`"canonical"` generate angular momentum and coupled spins as well on the +interaction nodes. +""" + class SolvingMode(Enum): """Types of modes for solving.""" @@ -226,7 +244,7 @@ def __init__( # noqa: C901, PLR0912, PLR0917 InteractionType, tuple[EdgeSettings, NodeSettings] ] | None = None, - formalism: str = "helicity", + formalism: SpinFormalism = "helicity", topology_building: str = "isobar", solving_mode: SolvingMode = SolvingMode.FAST, reload_pdg: bool = False, @@ -240,18 +258,13 @@ def __init__( # noqa: C901, PLR0912, PLR0917 self.__number_of_threads = NumberOfThreads.get() if interaction_type_settings is None: interaction_type_settings = {} - allowed_formalisms = [ - "helicity", - "canonical-helicity", - "canonical", - ] - if formalism not in allowed_formalisms: + if formalism not in set(SpinFormalism.__args__): # type: ignore[attr-defined] msg = ( f'Formalism "{formalism}" not implemented. Use one of' - f" {allowed_formalisms} instead." + f" {', '.join(SpinFormalism.__args__)} instead." # type: ignore[attr-defined] ) raise NotImplementedError(msg) - self.__formalism = str(formalism) + self.__formalism = formalism self.__particles = ParticleCollection() if particle_db is not None: self.__particles = particle_db @@ -343,7 +356,7 @@ def set_allowed_intermediate_particles( self.__intermediate_particle_filters = selected_particles.names @property - def formalism(self) -> str: + def formalism(self) -> SpinFormalism: return self.__formalism def add_final_state_grouping(self, fs_group: list[str] | list[list[str]]) -> None: @@ -744,7 +757,7 @@ class ReactionInfo: """Ordered collection of `StateTransition` instances.""" transitions: tuple[StateTransition, ...] = field(converter=_sort_tuple) - formalism: str = field(validator=instance_of(str)) + formalism: SpinFormalism = field(validator=in_(SpinFormalism.__args__)) # type: ignore[attr-defined] initial_state: FrozenDict[int, Particle] = field(init=False, repr=False, eq=False) final_state: FrozenDict[int, Particle] = field(init=False, repr=False, eq=False) diff --git a/tests/channels/test_psi2s_to_eta_k_kstar.py b/tests/channels/test_psi2s_to_eta_k_kstar.py index f7679f25..fcc1133c 100644 --- a/tests/channels/test_psi2s_to_eta_k_kstar.py +++ b/tests/channels/test_psi2s_to_eta_k_kstar.py @@ -4,6 +4,7 @@ import qrules from qrules.particle import ParticleCollection +from qrules.transition import SpinFormalism @pytest.mark.parametrize("formalism", ["helicity", "canonical-helicity"]) @@ -15,7 +16,7 @@ ["h(1)(1415)", "omega(1650)"], ], ) -def test_resonances(formalism, resonances, modified_pdg): +def test_resonances(formalism: SpinFormalism, resonances, modified_pdg): reaction = qrules.generate_transitions( initial_state=("psi(2S)", [+1, -1]), final_state=["eta", "K-", "K*(892)+"], diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index ebb5ad1b..a0d414e2 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,4 +1,5 @@ import logging +from typing import TYPE_CHECKING import pytest from _pytest.fixtures import SubRequest @@ -7,12 +8,15 @@ from qrules import ReactionInfo from qrules.topology import Edge, Topology +if TYPE_CHECKING: + from qrules.transition import SpinFormalism + logging.basicConfig(level=logging.ERROR) @pytest.fixture(scope="session", params=["canonical-helicity", "helicity"]) def reaction(request: SubRequest) -> ReactionInfo: - formalism: str = request.param + formalism: SpinFormalism = request.param return qrules.generate_transitions( initial_state=[("J/psi(1S)", [-1, 1])], final_state=["gamma", "pi0", "pi0"], diff --git a/tests/unit/io/test_dot.py b/tests/unit/io/test_dot.py index b65e6ce1..4f00f2bc 100644 --- a/tests/unit/io/test_dot.py +++ b/tests/unit/io/test_dot.py @@ -11,7 +11,7 @@ create_isobar_topologies, create_n_body_topology, ) -from qrules.transition import ReactionInfo +from qrules.transition import ReactionInfo, SpinFormalism def test_asdot(reaction: ReactionInfo): @@ -116,7 +116,7 @@ def test_asdot_no_label_overwriting(reaction: ReactionInfo): "formalism", ["canonical", "canonical-helicity", "helicity"], ) -def test_asdot_problemset(formalism: str): +def test_asdot_problemset(formalism: SpinFormalism): stm = qrules.StateTransitionManager( initial_state=[("J/psi(1S)", [+1])], final_state=["gamma", "pi0", "pi0"], diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index a8bb739e..10eb633d 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -9,6 +9,7 @@ _int_domain, create_interaction_settings, ) +from qrules.transition import SpinFormalism class TestInteractionType: @@ -60,7 +61,7 @@ def test_create_interaction_settings( particle_database: ParticleCollection, interaction_type: InteractionType, nbody_topology: bool, - formalism: str, + formalism: SpinFormalism, ): settings = create_interaction_settings( formalism,