Skip to content

Commit

Permalink
perf: add caching for value_of (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
mstechly authored Dec 12, 2024
1 parent f5c27e8 commit cf5ce14
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions src/bartiq/symbolics/sympy_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping
from functools import singledispatchmethod
from functools import lru_cache, singledispatchmethod
from typing import Callable, Concatenate, ParamSpec, TypeVar

import sympy
Expand Down Expand Up @@ -113,6 +113,25 @@ def _eval_wrapper(cls, *args, **kwargs):
return sympy_func


@lru_cache
def _value_of(expr: Expr) -> Number | None:
"""Compute a numerical value of an expression, return None if it's not possible."""
try:
value = N(expr).round(n=NUM_DIGITS_PRECISION)
except TypeError as e:
if str(e) == "Cannot round symbolic expression":
return None
else:
raise e

# Map to integer if possible
if int(value) == value or value.is_Float and value % 1 == 0:
value = int(value)
else:
value = float(value)
return value


class SympyBackend:

def __init__(self, parse_function: Callable[[str], Expr] = parse_to_sympy):
Expand Down Expand Up @@ -156,20 +175,7 @@ def reserved_functions(self) -> Iterable[str]:
@identity_for_numbers
def value_of(self, expr: Expr) -> Number | None:
"""Compute a numerical value of an expression, return None if it's not possible."""
try:
value = N(expr).round(n=NUM_DIGITS_PRECISION)
except TypeError as e:
if str(e) == "Cannot round symbolic expression":
return None
else:
raise e

# Map to integer if possible
if int(value) == value or value.is_Float and value % 1 == 0:
value = int(value)
else:
value = float(value)
return value
return _value_of(expr)

@identity_for_numbers
def substitute(
Expand Down

0 comments on commit cf5ce14

Please sign in to comment.