From d3dda39330173a88a427e72f50b4e6b25e39b001 Mon Sep 17 00:00:00 2001 From: denehoffman Date: Wed, 5 Jun 2024 01:44:43 -0400 Subject: [PATCH] fix: added + and * overload type hints --- py-rustitude/rustitude/__init__.pyi | 12 ------------ py-rustitude/src/amplitude.rs | 6 ++++++ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/py-rustitude/rustitude/__init__.pyi b/py-rustitude/rustitude/__init__.pyi index 2958ff3..18fd7c5 100644 --- a/py-rustitude/rustitude/__init__.pyi +++ b/py-rustitude/rustitude/__init__.pyi @@ -3,7 +3,6 @@ from typing import Self, overload __version__: str - class Parameter: amplitude: str name: str @@ -14,7 +13,6 @@ class Parameter: def __init__(self, amplitude: str, name: str, index: int) -> None: ... - class AmpOp: def print_tree(self): ... def real(self) -> Self: ... @@ -25,26 +23,22 @@ class AmpOp: @overload def __mul__(self, other: CohSum) -> CohSum: ... - def Scalar(name: str) -> AmpOp: ... def CScalar(name: str) -> AmpOp: ... def PCScalar(name: str) -> AmpOp: ... def PiecewiseM(name: str, bins: int, range: tuple[float, float]) -> AmpOp: ... - class Amplitude: name: str active: bool cache_position: int parameter_index_start: int - class CohSum: def __init__(self, terms: list[AmpOp]) -> None: ... def __add__(self, other: Self) -> Self: ... def __mul__(self, other: AmpOp) -> CohSum: ... - class Model: cohsums: list[CohSum] amplitudes: list[Amplitude] @@ -68,7 +62,6 @@ class Model: def activate(self, amplitude: str) -> None: ... def deactivate(self, amplitude: str) -> None: ... - class FourMomentum: e: float px: float @@ -86,7 +79,6 @@ class FourMomentum: def __add__(self, other: FourMomentum) -> FourMomentum: ... def __sub__(self, other: FourMomentum) -> FourMomentum: ... - class Event: index: int weight: float @@ -95,7 +87,6 @@ class Event: daughter_p4s: list[FourMomentum] eps: list[float] - class Dataset: events: list[Event] weights: list[float] @@ -124,12 +115,10 @@ class Dataset: @staticmethod def from_root(path: str) -> Dataset: ... - def open( file_name: str | Path, tree_name: str | None = None, *, pol_in_beam: bool = False ) -> Dataset: ... # noqa: A001 - class Manager: root: AmpOp amplitudes: list[Amplitude] @@ -150,7 +139,6 @@ class Manager: def activate(self, amplitude: str) -> None: ... def deactivate(self, amplitude: str) -> None: ... - class ExtendedLogLikelihood: root: AmpOp amplitudes: list[Amplitude] diff --git a/py-rustitude/src/amplitude.rs b/py-rustitude/src/amplitude.rs index e784d7b..b98a771 100644 --- a/py-rustitude/src/amplitude.rs +++ b/py-rustitude/src/amplitude.rs @@ -103,6 +103,12 @@ impl CohSum { fn print_tree(&self) { self.0.print_tree() } + fn __add__(&self, other: Self) -> CohSum { + (self.0.clone() + other.0).into() + } + fn __mul__(&self, other: AmpOp) -> Self { + (self.0.clone() * other.0).into() + } } impl From for CohSum {