Skip to content

Commit

Permalink
feat: add min and max to backend
Browse files Browse the repository at this point in the history
  • Loading branch information
mstechly committed Dec 19, 2024
1 parent 0228662 commit 2b2b724
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 7 deletions.
4 changes: 4 additions & 0 deletions src/bartiq/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,9 @@ class BartiqPreprocessingError(Exception):
"""Raised for errors during Bartiq function pre-processing."""


class BartiqPostprocessingError(Exception):
"""Raised for errors during Bartiq function post-processing."""


class BartiqCompilationError(Exception):
"""Raised for errors during Bartiq function compilation."""
4 changes: 2 additions & 2 deletions src/bartiq/repetitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ class CustomSequence(Generic[T]):
iterator_symbol: T

def get_sum(self, expr: TExpr[T], count: TExpr[T], backend: SymbolicBackend[T]) -> TExpr[T]:
return backend.sum(self.term_expression * expr, self.iterator_symbol, 0, count - 1)
return backend.sequence_sum(self.term_expression * expr, self.iterator_symbol, 0, count - 1)

def get_prod(self, expr: TExpr[T], count: TExpr[T], backend: SymbolicBackend[T]) -> TExpr[T]:
return backend.prod(self.term_expression * expr, self.iterator_symbol, 0, count - 1)
return backend.sequence_prod(self.term_expression * expr, self.iterator_symbol, 0, count - 1)

def substitute_symbols(
self, inputs: dict[str, TExpr[T]], backend: SymbolicBackend[T], functions_map=None
Expand Down
10 changes: 8 additions & 2 deletions src/bartiq/symbolics/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,14 @@ def compare(self, lhs: TExpr[T], rhs: TExpr[T]) -> ComparisonResult:
def func(self, func_name: str) -> Callable[..., TExpr[T]]:
"""Obtain an implementation of a function with given name."""

def sum(self, term: TExpr[T], iterator_symbol: TExpr[T], start: TExpr[T], end: TExpr[T]) -> TExpr[T]:
def min(self, *args):
"""Returns a smallest value from given args."""

def max(self, *args):
"""Returns a biggest value from given args."""

def sequence_sum(self, term: TExpr[T], iterator_symbol: TExpr[T], start: TExpr[T], end: TExpr[T]) -> TExpr[T]:
"""Express a sum of terms expressed using `iterator_symbol`."""

def prod(self, term: TExpr[T], iterator_symbol: TExpr[T], start: TExpr[T], end: TExpr[T]) -> TExpr[T]:
def sequence_prod(self, term: TExpr[T], iterator_symbol: TExpr[T], start: TExpr[T], end: TExpr[T]) -> TExpr[T]:
"""Express a product of terms expressed using `iterator_symbol`."""
12 changes: 10 additions & 2 deletions src/bartiq/symbolics/sympy_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,19 @@ def func(self, func_name: str) -> Callable[..., TExpr[Expr]]:
except KeyError:
return sympy.Function(func_name)

def sum(self, term: TExpr[T], iterator_symbol: TExpr[T], start: TExpr[T], end: TExpr[T]) -> TExpr[T]:
def min(self, *args):
"""Returns a smallest value from given args."""
return sympy.Min(*args)

def max(self, *args):
"""Returns a biggest value from given args."""
return sympy.Max(*args)

def sequence_sum(self, term: TExpr[T], iterator_symbol: TExpr[T], start: TExpr[T], end: TExpr[T]) -> TExpr[T]:
"""Express a sum of terms expressed using `iterator_symbol`."""
return sympy.Sum(term, (iterator_symbol, start, end))

def prod(self, term: TExpr[T], iterator_symbol: TExpr[T], start: TExpr[T], end: TExpr[T]) -> TExpr[T]:
def sequence_prod(self, term: TExpr[T], iterator_symbol: TExpr[T], start: TExpr[T], end: TExpr[T]) -> TExpr[T]:
"""Express a product of terms expressed using `iterator_symbol`."""
return sympy.Product(term, (iterator_symbol, start, end))

Expand Down
1 change: 1 addition & 0 deletions src/bartiq/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


T = TypeVar("T")
# NOTE: Actually, it should be `Routine[T]` and `CompiledRoutine[T]`, but such syntax is not currently supported.
AnyRoutine = TypeVar("AnyRoutine", Routine, CompiledRoutine)


Expand Down
17 changes: 16 additions & 1 deletion tests/symbolics/test_sympy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
"""

import pytest
from sympy import E, cos, exp, pi, sin, sqrt, sympify
from sympy import E
from sympy import Max as sympy_max
from sympy import Min as sympy_min
from sympy import cos, exp, pi, sin, sqrt, symbols, sympify

from bartiq.errors import BartiqCompilationError
from bartiq.symbolics import sympy_backend
Expand Down Expand Up @@ -163,3 +166,15 @@ def g(x):

assert result_1 == 10
assert result_2 == 10


def test_min_max_works_for_numerical_values(backend):
values = [-5, 0, 1, 23.4]
assert backend.min(*values) == min(*values)
assert backend.max(*values) == max(*values)


def test_min_max_works_for_symbols(backend):
values = symbols("a, b, c")
assert backend.min(*values) == sympy_min(*values)
assert backend.max(*values) == sympy_max(*values)

0 comments on commit 2b2b724

Please sign in to comment.