Skip to content

Commit

Permalink
ENH: specify allowed formalisms with Literal (#253)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer authored Mar 3, 2024
1 parent fb132e2 commit 9640f1d
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 21 deletions.
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def create_constraints_inventory() -> None:
api_target_substitutions: dict[str, str | tuple[str, str]] = {
"EdgeType": "typing.TypeVar",
"GraphEdgePropertyMap": ("obj", "qrules.argument_handling.GraphEdgePropertyMap"),
"GraphNodePropertyMap": ("obj", "qrules.argument_handling.GraphNodePropertyMap"),
"GraphElementProperties": ("obj", "qrules.solving.GraphElementProperties"),
"GraphNodePropertyMap": ("obj", "qrules.argument_handling.GraphNodePropertyMap"),
"GraphSettings": ("obj", "qrules.solving.GraphSettings"),
"InitialFacts": ("obj", "qrules.combinatorics.InitialFacts"),
"NewEdgeType": "typing.TypeVar",
Expand All @@ -75,6 +75,7 @@ def create_constraints_inventory() -> None:
"ParticleWithSpin": ("obj", "qrules.particle.ParticleWithSpin"),
"qrules.topology.EdgeType": "typing.TypeVar",
"qrules.topology.NodeType": "typing.TypeVar",
"SpinFormalism": ("obj", "qrules.transition.SpinFormalism"),
"StateDefinition": ("obj", "qrules.combinatorics.StateDefinition"),
"TypeAlias": "typing.TypeAlias",
"typing_extensions.TypeAlias": "typing.TypeAlias",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies = [
"particle",
"python-constraint",
"tqdm >=4.24.0", # autonotebook
'typing-extensions; python_version <"3.10.0"', # typing.TypeAlias
'typing-extensions; python_version <"3.10.0"', # Literal, Protocol, TypeAlias
]
description = "Rule-based particle reaction problem solver on a quantum number level"
dynamic = ["version"]
Expand Down
3 changes: 2 additions & 1 deletion src/qrules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
EdgeSettings,
ProblemSet,
ReactionInfo,
SpinFormalism,
StateTransitionManager,
)

Expand Down Expand Up @@ -264,7 +265,7 @@ def generate_transitions( # noqa: PLR0917
final_state: Sequence[StateDefinition],
allowed_intermediate_particles: list[str] | None = None,
allowed_interaction_types: str | Iterable[str] | None = None,
formalism: str = "canonical-helicity",
formalism: SpinFormalism = "canonical-helicity",
particle_db: ParticleCollection | None = None,
mass_conservation_factor: float | None = 3.0,
max_angular_momentum: int = 2,
Expand Down
3 changes: 2 additions & 1 deletion src/qrules/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

if TYPE_CHECKING:
from qrules.particle import Particle, ParticleCollection
from qrules.transition import SpinFormalism

__QRULES_PATH = dirname(realpath(__file__))
ADDITIONAL_PARTICLES_DEFINITIONS_PATH: str = join(
Expand Down Expand Up @@ -118,7 +119,7 @@ def from_str(description: str) -> InteractionType:


def create_interaction_settings( # noqa: PLR0917
formalism: str,
formalism: SpinFormalism,
particle_db: ParticleCollection,
nbody_topology: bool = False,
mass_conservation_factor: float | None = 3.0,
Expand Down
37 changes: 25 additions & 12 deletions src/qrules/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import attrs
from attrs import define, field, frozen
from attrs.validators import instance_of
from attrs.validators import in_, instance_of
from tqdm.auto import tqdm

from qrules._implementers import implement_pretty_repr
Expand Down Expand Up @@ -74,6 +74,11 @@
create_n_body_topology,
)

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
Expand All @@ -83,6 +88,19 @@

_LOGGER = logging.getLogger(__name__)

SpinFormalism = Literal[
"helicity",
"canonical-helicity",
"canonical",
]
"""Name for the spin formalism to be used.
The options :code:`"helicity"`, :code:`"canonical-helicity"`, and :code:`"canonical"`
are all used for the helicity formalism, but :code:`"canonical-helicity"` and
:code:`"canonical"` generate angular momentum and coupled spins as well on the
interaction nodes.
"""


class SolvingMode(Enum):
"""Types of modes for solving."""
Expand Down Expand Up @@ -226,7 +244,7 @@ def __init__( # noqa: C901, PLR0912, PLR0917
InteractionType, tuple[EdgeSettings, NodeSettings]
]
| None = None,
formalism: str = "helicity",
formalism: SpinFormalism = "helicity",
topology_building: str = "isobar",
solving_mode: SolvingMode = SolvingMode.FAST,
reload_pdg: bool = False,
Expand All @@ -240,18 +258,13 @@ def __init__( # noqa: C901, PLR0912, PLR0917
self.__number_of_threads = NumberOfThreads.get()
if interaction_type_settings is None:
interaction_type_settings = {}
allowed_formalisms = [
"helicity",
"canonical-helicity",
"canonical",
]
if formalism not in allowed_formalisms:
if formalism not in set(SpinFormalism.__args__): # type: ignore[attr-defined]
msg = (
f'Formalism "{formalism}" not implemented. Use one of'
f" {allowed_formalisms} instead."
f" {', '.join(SpinFormalism.__args__)} instead." # type: ignore[attr-defined]
)
raise NotImplementedError(msg)
self.__formalism = str(formalism)
self.__formalism = formalism
self.__particles = ParticleCollection()
if particle_db is not None:
self.__particles = particle_db
Expand Down Expand Up @@ -343,7 +356,7 @@ def set_allowed_intermediate_particles(
self.__intermediate_particle_filters = selected_particles.names

@property
def formalism(self) -> str:
def formalism(self) -> SpinFormalism:
return self.__formalism

def add_final_state_grouping(self, fs_group: list[str] | list[list[str]]) -> None:
Expand Down Expand Up @@ -744,7 +757,7 @@ class ReactionInfo:
"""Ordered collection of `StateTransition` instances."""

transitions: tuple[StateTransition, ...] = field(converter=_sort_tuple)
formalism: str = field(validator=instance_of(str))
formalism: SpinFormalism = field(validator=in_(SpinFormalism.__args__)) # type: ignore[attr-defined]

initial_state: FrozenDict[int, Particle] = field(init=False, repr=False, eq=False)
final_state: FrozenDict[int, Particle] = field(init=False, repr=False, eq=False)
Expand Down
3 changes: 2 additions & 1 deletion tests/channels/test_psi2s_to_eta_k_kstar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import qrules
from qrules.particle import ParticleCollection
from qrules.transition import SpinFormalism


@pytest.mark.parametrize("formalism", ["helicity", "canonical-helicity"])
Expand All @@ -15,7 +16,7 @@
["h(1)(1415)", "omega(1650)"],
],
)
def test_resonances(formalism, resonances, modified_pdg):
def test_resonances(formalism: SpinFormalism, resonances, modified_pdg):
reaction = qrules.generate_transitions(
initial_state=("psi(2S)", [+1, -1]),
final_state=["eta", "K-", "K*(892)+"],
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import TYPE_CHECKING

import pytest
from _pytest.fixtures import SubRequest
Expand All @@ -7,12 +8,15 @@
from qrules import ReactionInfo
from qrules.topology import Edge, Topology

if TYPE_CHECKING:
from qrules.transition import SpinFormalism

logging.basicConfig(level=logging.ERROR)


@pytest.fixture(scope="session", params=["canonical-helicity", "helicity"])
def reaction(request: SubRequest) -> ReactionInfo:
formalism: str = request.param
formalism: SpinFormalism = request.param
return qrules.generate_transitions(
initial_state=[("J/psi(1S)", [-1, 1])],
final_state=["gamma", "pi0", "pi0"],
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/io/test_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
create_isobar_topologies,
create_n_body_topology,
)
from qrules.transition import ReactionInfo
from qrules.transition import ReactionInfo, SpinFormalism


def test_asdot(reaction: ReactionInfo):
Expand Down Expand Up @@ -116,7 +116,7 @@ def test_asdot_no_label_overwriting(reaction: ReactionInfo):
"formalism",
["canonical", "canonical-helicity", "helicity"],
)
def test_asdot_problemset(formalism: str):
def test_asdot_problemset(formalism: SpinFormalism):
stm = qrules.StateTransitionManager(
initial_state=[("J/psi(1S)", [+1])],
final_state=["gamma", "pi0", "pi0"],
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
_int_domain,
create_interaction_settings,
)
from qrules.transition import SpinFormalism


class TestInteractionType:
Expand Down Expand Up @@ -60,7 +61,7 @@ def test_create_interaction_settings(
particle_database: ParticleCollection,
interaction_type: InteractionType,
nbody_topology: bool,
formalism: str,
formalism: SpinFormalism,
):
settings = create_interaction_settings(
formalism,
Expand Down

0 comments on commit 9640f1d

Please sign in to comment.