diff --git a/hotxlfp/formulas/logic.py b/hotxlfp/formulas/logic.py index 3e3fd7a..719d1eb 100644 --- a/hotxlfp/formulas/logic.py +++ b/hotxlfp/formulas/logic.py @@ -7,69 +7,70 @@ from . import error from . import utils import torch +import numpy as np -@dispatcher.register_for('AND') +@dispatcher.register_for("AND") def AND(*args): args = utils.iflatten(args) return all(args) -@dispatcher.register_for('IF') +@dispatcher.register_for("IF") def IF(test, then, otherwise): if isinstance(test, error.XLError): return error.XLError - test_condition = torch.tensor(test, dtype=torch.bool) - if test_condition.all(): - if isinstance(then, error.XLError): - return then - return torch.tensor(then, dtype=torch.double) - elif (~test_condition).all(): - if isinstance(otherwise, error.XLError): - return otherwise - return torch.tensor(otherwise, dtype=torch.double) - if isinstance(then, error.XLError): return then if isinstance(otherwise, error.XLError): return otherwise - + + if ( + isinstance(then, str) + or isinstance(otherwise, str) + or (isinstance(then, np.ndarray) and then.dtype.kind == "U") + or (isinstance(otherwise, np.ndarray) and otherwise.dtype.kind == "U") + ): + then_str = np.array(then, dtype="U") + otherwise_str = np.array(otherwise, dtype="U") + return np.where(np.array(test, dtype="b"), then_str, otherwise_str) + return torch.where( - test_condition, + torch.tensor(test, dtype=torch.bool), torch.tensor(then, dtype=torch.double), torch.tensor(otherwise, dtype=torch.double), ) -@dispatcher.register_for('IFERROR') +@dispatcher.register_for("IFERROR") def IFERROR(value, value_if_error): return value if not isinstance(value, error.XLError) else value_if_error -@dispatcher.register_for('IFNA') +@dispatcher.register_for("IFNA") def IFNA(value, value_if_na): return value if value != error.NOT_AVAILABLE else value_if_na -@dispatcher.register_for('NOT') +@dispatcher.register_for("NOT") def NOT(boolean): return not boolean -@dispatcher.register_for('XOR') +@dispatcher.register_for("XOR") def XOR(*args): args = utils.iflatten(args) result = sum(bool(a) for a in args) return bool(result & 1) -@dispatcher.register_for('OR') +@dispatcher.register_for("OR") def OR(*args): args = utils.iflatten(args) return any(args) -@dispatcher.register_for('SWITCH') +@dispatcher.register_for("SWITCH") def SWITCH(target_value, *args): if len(args) <= 1: return error.NOT_AVAILABLE @@ -83,7 +84,7 @@ def SWITCH(target_value, *args): return error.NOT_AVAILABLE -@dispatcher.register_for('IFS') +@dispatcher.register_for("IFS") def IFS(*args): for pair in zip(args[::2], args[1::2]): if pair[0]: @@ -93,11 +94,12 @@ def IFS(*args): # Compatibility functions -@dispatcher.register_for('TRUE') + +@dispatcher.register_for("TRUE") def TRUE(): return True -@dispatcher.register_for('FALSE') +@dispatcher.register_for("FALSE") def FALSE(): return False diff --git a/setup.py b/setup.py index 7f103cc..d7ab39e 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="hotxlfp", - version="0.0.11+unc.29", + version="0.0.11+unc.30", packages=[ "hotxlfp", "hotxlfp._compat", diff --git a/tests/test_formula_parser.py b/tests/test_formula_parser.py index b843d50..f51f022 100644 --- a/tests/test_formula_parser.py +++ b/tests/test_formula_parser.py @@ -4,6 +4,7 @@ import torch from hotxlfp import Parser from math import pi +import numpy as np def _test_equation( @@ -20,10 +21,13 @@ def _test_equation( except (error.XLError, TypeError): assert should_fail return - assert isinstance(result, torch.Tensor) - assert ( - torch.abs(result - torch.tensor(answer)) < 0.000001 - ).all(), f"{result} != {answer}" + if isinstance(result, np.ndarray): + assert (result == answer).all(), f"{result} != {answer}" + else: + assert isinstance(result, torch.Tensor) + assert ( + torch.abs(result - torch.tensor(answer)) < 0.000001 + ).all(), f"{result} != {answer}" class TestFormulaParser(unittest.TestCase): @@ -347,7 +351,7 @@ def test_decimal(self): ) def test_order_of_operations(self): - _test_equation(equation="2^-2-1", variables={"a1": [1]}, answer=[-.75]) + _test_equation(equation="2^-2-1", variables={"a1": [1]}, answer=[-0.75]) _test_equation(equation="(2^(-1))", variables={"a1": [1]}, answer=[0.5]) _test_equation(equation="((2^(-1))-1)", variables={"a1": [1]}, answer=[-0.5]) _test_equation(equation="2^2-1", variables={"a1": [1]}, answer=[3]) @@ -360,55 +364,116 @@ def test_order_of_operations(self): _test_equation(equation="1 + 2 * 3 - 4", variables={"a1": [1]}, answer=[3]) def test_if_statement_args(self): - _test_equation(equation="IF(a1 > 10, 1, 100, 400)", variables={"a1": [4]}, should_fail=True) - _test_equation(equation="IF(a1 > 10, 400)", variables={"a1": [4]}, should_fail=True) - _test_equation(equation="IF(a1 > 10, )", variables={"a1": [4]}, should_fail=True) + _test_equation( + equation="IF(a1 > 10, 1, 100, 400)", variables={"a1": [4]}, should_fail=True + ) + _test_equation( + equation="IF(a1 > 10, 400)", variables={"a1": [4]}, should_fail=True + ) + _test_equation( + equation="IF(a1 > 10, )", variables={"a1": [4]}, should_fail=True + ) _test_equation(equation="IF(,, )", variables={"a1": [4]}, should_fail=True) - _test_equation(equation="IF(a1 > 100, 40, IF(a1 > 1, 4, 56))", variables={"a1": [40]}, answer=[4]) - _test_equation(equation="IF(a1 > 10, 40, IF(a1 > 10, 4))", variables={"a1": [4]}, should_fail=True) + _test_equation( + equation="IF(a1 > 100, 40, IF(a1 > 1, 4, 56))", + variables={"a1": [40]}, + answer=[4], + ) + _test_equation( + equation="IF(a1 > 10, 40, IF(a1 > 10, 4))", + variables={"a1": [4]}, + should_fail=True, + ) + _test_equation( + equation="IF(a1 > 10, 'abc', 'def')", variables={"a1": [4]}, answer=["def"] + ) + _test_equation( + equation="IF(a1 > 100, 'abc', IF(a1 > 1, 4, 56))", + variables={"a1": [40]}, + answer=["4.0"], + ) def test_tensors(self): - _test_equation(equation="MIN(a1 * 2, 2, 23, a1)", variables={"a1": [5]}, answer=[2]) + _test_equation( + equation="MIN(a1 * 2, 2, 23, a1)", variables={"a1": [5]}, answer=[2] + ) _test_equation(equation="MIN(2, a1 * 2)", variables={"a1": [5]}, answer=[2]) _test_equation(equation="MAX(a1 * 2, 2)", variables={"a1": [5]}, answer=[10]) _test_equation(equation="MAX(2, a1 * 2)", variables={"a1": [5]}, answer=[10]) - _test_equation(equation="MAX(MAX(2, a1 * 2), 100)", variables={"a1": [5, 4]}, answer=[100, 100]) + _test_equation( + equation="MAX(MAX(2, a1 * 2), 100)", + variables={"a1": [5, 4]}, + answer=[100, 100], + ) _test_equation(equation="5", variables={"a1": [5, 4]}, answer=[5]) _test_equation(equation="SQRT(100)", variables={"a1": [5]}, answer=[10]) - _test_equation(equation="CEILING(a1)", variables={"a1": [4.5, -1.2]}, answer=[5, -1]) - _test_equation(equation="CEILING(a1, a2)", variables={"a1": [0.5, 0.5], "a2": [1, 2]}, answer=[1, 2]) - _test_equation(equation="CEILING(a1, a2)", variables={"a1": [0.5, 0.5], "a2": [2]}, answer=[2, 2]) - _test_equation(equation="CEILING(a1, a2)", variables={"a1": [0.5], "a2": [1]}, answer=[1]) - _test_equation(equation="IF(a1 <> a2, 1, 0)", variables={"a1": [3, 2], "a2": [3, 3]}, answer=[0, 1]) - _test_equation(equation="IF(a1 <> a2, 1, 0)", variables={"a1": 4, "a2": 2}, answer=1) - _test_equation(equation="IF(a1 < a2, 1, 0)", variables={"a1": [1, 2], "a2": [0, 0]}, answer=[0, 0]) - _test_equation(equation="IF(a1 < a2, 1, 0)", variables={"a1": 2, "a2": 3}, answer=1) + _test_equation( + equation="CEILING(a1)", variables={"a1": [4.5, -1.2]}, answer=[5, -1] + ) + _test_equation( + equation="CEILING(a1, a2)", + variables={"a1": [0.5, 0.5], "a2": [1, 2]}, + answer=[1, 2], + ) + _test_equation( + equation="CEILING(a1, a2)", + variables={"a1": [0.5, 0.5], "a2": [2]}, + answer=[2, 2], + ) + _test_equation( + equation="CEILING(a1, a2)", variables={"a1": [0.5], "a2": [1]}, answer=[1] + ) + _test_equation( + equation="IF(a1 <> a2, 1, 0)", + variables={"a1": [3, 2], "a2": [3, 3]}, + answer=[0, 1], + ) + _test_equation( + equation="IF(a1 <> a2, 1, 0)", variables={"a1": 4, "a2": 2}, answer=1 + ) + _test_equation( + equation="IF(a1 < a2, 1, 0)", + variables={"a1": [1, 2], "a2": [0, 0]}, + answer=[0, 0], + ) + _test_equation( + equation="IF(a1 < a2, 1, 0)", variables={"a1": 2, "a2": 3}, answer=1 + ) def test_scientific_notation(self): - _test_equation(equation="2e2", variables={"a1" : [1.1]}, answer=[200]) - _test_equation(equation="5(m)", variables={"m" : [10]}, answer=[50]) - _test_equation(equation="5(e)", variables={"e" : [10]}, answer=[50]) - _test_equation(equation="5(evar)", variables={"evar" : [10]}, answer=[50]) - _test_equation(equation="5(vare)", variables={"vare" : [10]}, answer=[50]) - _test_equation(equation="1 e 2", variables={"a1" : [1.1]}, answer=[100]) - _test_equation(equation="1e2", variables={"a1" : [1.1]}, answer=[100]) - _test_equation(equation="2*1e2", variables={"a1" : [1.1]}, answer=[200]) - _test_equation(equation="2*1e2^3", variables={"a1" : [1.1]}, answer=[2000000]) - _test_equation(equation="(2*1e2)^3", variables={"a1" : [1.1]}, answer=[8000000]) - _test_equation(equation="(2)e(4)", variables={"a1" : [1.1]}, answer=[8000000], should_fail=True) - _test_equation(equation="(2)e(4)", variables={"a1" : [1.1]}, should_fail=True) - _test_equation(equation="0.2e2", variables={"a1" : [1.1]}, answer=20) - _test_equation(equation="0.2e0.2", variables={"a1" : [1.1]}, answer=0.2 * (10 ** 0.2)) - _test_equation(equation="2e-1", variables={"a1" : [1.1]}, answer=0.2) - _test_equation(equation="-2e1", variables={"a1" : [1.1]}, answer=-20) - _test_equation(equation="-2e-1", variables={"a1" : [1.1]}, answer=-0.2) - _test_equation(equation="-2e-.1", variables={"a1" : [1.1]}, answer=-2 * (10 ** (-.1))) - - - _test_equation(equation="2E2", variables={"a1" : [1.1]}, answer=[200]) - _test_equation(equation="2*1E2", variables={"a1" : [1.1]}, answer=[200]) - _test_equation(equation="0.2E2", variables={"a1" : [1.1]}, answer=20) - _test_equation(equation="-2E-1", variables={"a1" : [1.1]}, answer=-0.2) + _test_equation(equation="2e2", variables={"a1": [1.1]}, answer=[200]) + _test_equation(equation="5(m)", variables={"m": [10]}, answer=[50]) + _test_equation(equation="5(e)", variables={"e": [10]}, answer=[50]) + _test_equation(equation="5(evar)", variables={"evar": [10]}, answer=[50]) + _test_equation(equation="5(vare)", variables={"vare": [10]}, answer=[50]) + _test_equation(equation="1 e 2", variables={"a1": [1.1]}, answer=[100]) + _test_equation(equation="1e2", variables={"a1": [1.1]}, answer=[100]) + _test_equation(equation="2*1e2", variables={"a1": [1.1]}, answer=[200]) + _test_equation(equation="2*1e2^3", variables={"a1": [1.1]}, answer=[2000000]) + _test_equation(equation="(2*1e2)^3", variables={"a1": [1.1]}, answer=[8000000]) + _test_equation( + equation="(2)e(4)", + variables={"a1": [1.1]}, + answer=[8000000], + should_fail=True, + ) + _test_equation(equation="(2)e(4)", variables={"a1": [1.1]}, should_fail=True) + _test_equation(equation="0.2e2", variables={"a1": [1.1]}, answer=20) + _test_equation( + equation="0.2e0.2", variables={"a1": [1.1]}, answer=0.2 * (10**0.2) + ) + _test_equation(equation="2e-1", variables={"a1": [1.1]}, answer=0.2) + _test_equation(equation="-2e1", variables={"a1": [1.1]}, answer=-20) + _test_equation(equation="-2e-1", variables={"a1": [1.1]}, answer=-0.2) + _test_equation( + equation="-2e-.1", variables={"a1": [1.1]}, answer=-2 * (10 ** (-0.1)) + ) + + _test_equation(equation="2E2", variables={"a1": [1.1]}, answer=[200]) + _test_equation(equation="2*1E2", variables={"a1": [1.1]}, answer=[200]) + _test_equation(equation="0.2E2", variables={"a1": [1.1]}, answer=20) + _test_equation(equation="-2E-1", variables={"a1": [1.1]}, answer=-0.2) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()