Skip to content

Commit

Permalink
catch numexpr errors and try with eval() and also test the case that …
Browse files Browse the repository at this point in the history
…a module returns a np.array
  • Loading branch information
tdixon97 committed Jan 28, 2025
1 parent 113878b commit 8244792
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 11 deletions.
27 changes: 16 additions & 11 deletions src/lgdo/types/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,17 +367,22 @@ def _make_lgdo(data):

# use numexpr if we are only dealing with numpy data types (and no global dictionary)
if not has_ak and modules is None:
out_data = ne.evaluate(
expr,
local_dict=(self_unwrap | parameters),
)

msg = f"...the result is {out_data!r}"
log.debug(msg)

# need to convert back to LGDO
# np.evaluate should always return a numpy thing?
return _make_lgdo(out_data)
try:
out_data = ne.evaluate(
expr,
local_dict=(self_unwrap | parameters),
)

msg = f"...the result is {out_data!r}"
log.debug(msg)

# need to convert back to LGDO
# np.evaluate should always return a numpy thing?
return _make_lgdo(out_data)

except Exception:
msg = f"Warning {expr} could not be evaluated with numexpr probably due to some not allowed characters, trying with eval()."
log.debug(msg)

# resort to good ol' eval()
globs = {"ak": ak, "np": np}
Expand Down
14 changes: 14 additions & 0 deletions tests/types/test_table_eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import dbetto
import hist
import numpy as np
import pytest
Expand Down Expand Up @@ -85,6 +86,19 @@ def test_eval_dependency():
res = obj.eval("lgdo.Array([1,2,3])", {}, modules={"lgdo": lgdo})
assert res == lgdo.Array([1, 2, 3])

# test with module returning np.array
assert obj.eval("np.sum(a)", {}, modules={"np": np}).value == np.int64(10)

# check bad type
with pytest.raises(RuntimeError):
obj.eval("hist.Hist()", modules={"hist": hist})

# check impossible numexpr can still run
assert np.allclose(
obj.eval(
"a*args.value",
{"args": dbetto.AttrsDict({"value": 2})},
modules={"lgdo": lgdo},
).view_as("np"),
[2, 4, 6, 8],
)

0 comments on commit 8244792

Please sign in to comment.