Skip to content

Commit

Permalink
fix: added + and * overload type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
denehoffman committed Jun 5, 2024
1 parent bc47cf1 commit d3dda39
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
12 changes: 0 additions & 12 deletions py-rustitude/rustitude/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ from typing import Self, overload

__version__: str


class Parameter:
amplitude: str
name: str
Expand All @@ -14,7 +13,6 @@ class Parameter:

def __init__(self, amplitude: str, name: str, index: int) -> None: ...


class AmpOp:
def print_tree(self): ...
def real(self) -> Self: ...
Expand All @@ -25,26 +23,22 @@ class AmpOp:
@overload
def __mul__(self, other: CohSum) -> CohSum: ...


def Scalar(name: str) -> AmpOp: ...
def CScalar(name: str) -> AmpOp: ...
def PCScalar(name: str) -> AmpOp: ...
def PiecewiseM(name: str, bins: int, range: tuple[float, float]) -> AmpOp: ...


class Amplitude:
name: str
active: bool
cache_position: int
parameter_index_start: int


class CohSum:
def __init__(self, terms: list[AmpOp]) -> None: ...
def __add__(self, other: Self) -> Self: ...
def __mul__(self, other: AmpOp) -> CohSum: ...


class Model:
cohsums: list[CohSum]
amplitudes: list[Amplitude]
Expand All @@ -68,7 +62,6 @@ class Model:
def activate(self, amplitude: str) -> None: ...
def deactivate(self, amplitude: str) -> None: ...


class FourMomentum:
e: float
px: float
Expand All @@ -86,7 +79,6 @@ class FourMomentum:
def __add__(self, other: FourMomentum) -> FourMomentum: ...
def __sub__(self, other: FourMomentum) -> FourMomentum: ...


class Event:
index: int
weight: float
Expand All @@ -95,7 +87,6 @@ class Event:
daughter_p4s: list[FourMomentum]
eps: list[float]


class Dataset:
events: list[Event]
weights: list[float]
Expand Down Expand Up @@ -124,12 +115,10 @@ class Dataset:
@staticmethod
def from_root(path: str) -> Dataset: ...


def open(
file_name: str | Path, tree_name: str | None = None, *, pol_in_beam: bool = False
) -> Dataset: ... # noqa: A001


class Manager:
root: AmpOp
amplitudes: list[Amplitude]
Expand All @@ -150,7 +139,6 @@ class Manager:
def activate(self, amplitude: str) -> None: ...
def deactivate(self, amplitude: str) -> None: ...


class ExtendedLogLikelihood:
root: AmpOp
amplitudes: list[Amplitude]
Expand Down
6 changes: 6 additions & 0 deletions py-rustitude/src/amplitude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ impl CohSum {
fn print_tree(&self) {
self.0.print_tree()
}
fn __add__(&self, other: Self) -> CohSum {
(self.0.clone() + other.0).into()
}
fn __mul__(&self, other: AmpOp) -> Self {
(self.0.clone() * other.0).into()
}
}

impl From<rust::CohSum> for CohSum {
Expand Down

0 comments on commit d3dda39

Please sign in to comment.