Skip to content

Commit

Permalink
Fix sympy symbol name clashes (#202)
Browse files Browse the repository at this point in the history
Make sure symbols like `N`, `beta`, ... are used as scalar symbols, and not as sympy functions with the same name.

See also ICB-DCM/pyPESTO#1048
  • Loading branch information
dweindl authored May 8, 2023
1 parent bd0176f commit dc0be75
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 9 deletions.
5 changes: 3 additions & 2 deletions petab/calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

import numpy as np
import pandas as pd
import petab
import sympy
from sympy.abc import _clash

import petab
from .C import *

__all__ = ['calculate_residuals', 'calculate_residuals_for_table',
Expand Down Expand Up @@ -138,7 +139,7 @@ def get_symbolic_noise_formulas(observable_df) -> Dict[str, sympy.Expr]:
if NOISE_FORMULA not in observable_df.columns:
noise_formula = None
else:
noise_formula = sympy.sympify(row.noiseFormula)
noise_formula = sympy.sympify(row.noiseFormula, locals=_clash)
noise_formulas[observable_id] = noise_formula
return noise_formulas

Expand Down
11 changes: 6 additions & 5 deletions petab/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
import logging
import numbers
import re
from typing import Optional, Iterable, Any
from collections import Counter
from typing import Any, Iterable, Optional

import numpy as np
import pandas as pd
import sympy as sp
from sympy.abc import _clash

import petab
from . import (core, parameters, measurements)
from .models import Model
from . import (core, measurements, parameters)
from .C import * # noqa: F403
from .models import Model

logger = logging.getLogger(__name__)
__all__ = ['assert_all_parameters_present_in_parameter_df',
Expand Down Expand Up @@ -287,15 +288,15 @@ def check_observable_df(observable_df: pd.DataFrame) -> None:
for row in observable_df.itertuples():
obs = getattr(row, OBSERVABLE_FORMULA)
try:
sp.sympify(obs)
sp.sympify(obs, locals=_clash)
except sp.SympifyError as e:
raise AssertionError(
f"Cannot parse expression '{obs}' "
f"for observable {row.Index}: {e}") from e

noise = getattr(row, NOISE_FORMULA)
try:
sympified_noise = sp.sympify(noise)
sympified_noise = sp.sympify(noise, locals=_clash)
if sympified_noise is None \
or (sympified_noise.is_Number
and not sympified_noise.is_finite):
Expand Down
5 changes: 3 additions & 2 deletions petab/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import re
from collections import OrderedDict
from pathlib import Path
from typing import List, Union, Literal
from typing import List, Literal, Union

import pandas as pd
import sympy as sp
from sympy.abc import _clash

from . import core, lint
from .C import * # noqa: F403
Expand Down Expand Up @@ -97,7 +98,7 @@ def get_output_parameters(
output_parameters = OrderedDict()

for formula in formulas:
free_syms = sorted(sp.sympify(formula).free_symbols,
free_syms = sorted(sp.sympify(formula, locals=_clash).free_symbols,
key=lambda symbol: symbol.name)
for free_sym in free_syms:
sym = str(free_sym)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ def test_get_output_parameters():

assert output_parameters == ['offset', 'scaling']

# test sympy-special symbols (e.g. N, beta, ...)
# see https://github.com/ICB-DCM/pyPESTO/issues/1048
observable_df = pd.DataFrame(data={
OBSERVABLE_ID: ['observable_1'],
OBSERVABLE_NAME: ['observable name 1'],
OBSERVABLE_FORMULA: ['observable_1 * N + beta'],
NOISE_FORMULA: [1],
}).set_index(OBSERVABLE_ID)

output_parameters = petab.get_output_parameters(
observable_df, SbmlModel(sbml_model=ss_model.model))

assert output_parameters == ['N', 'beta']


def test_get_formula_placeholders():
"""Test get_formula_placeholders"""
Expand Down

0 comments on commit dc0be75

Please sign in to comment.