Skip to content

Commit

Permalink
Merge pull request #33 from uncountableinc/t/add-str-support-to-if
Browse files Browse the repository at this point in the history
add string support to if statement
  • Loading branch information
leb2 authored Nov 10, 2023
2 parents 9c735d3 + fd4396e commit 4261fab
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 69 deletions.
48 changes: 25 additions & 23 deletions hotxlfp/formulas/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,69 +7,70 @@
from . import error
from . import utils
import torch
import numpy as np


@dispatcher.register_for('AND')
@dispatcher.register_for("AND")
def AND(*args):
args = utils.iflatten(args)
return all(args)


@dispatcher.register_for('IF')
@dispatcher.register_for("IF")
def IF(test, then, otherwise):
if isinstance(test, error.XLError):
return error.XLError
test_condition = torch.tensor(test, dtype=torch.bool)
if test_condition.all():
if isinstance(then, error.XLError):
return then
return torch.tensor(then, dtype=torch.double)
elif (~test_condition).all():
if isinstance(otherwise, error.XLError):
return otherwise
return torch.tensor(otherwise, dtype=torch.double)

if isinstance(then, error.XLError):
return then
if isinstance(otherwise, error.XLError):
return otherwise


if (
isinstance(then, str)
or isinstance(otherwise, str)
or (isinstance(then, np.ndarray) and then.dtype.kind == "U")
or (isinstance(otherwise, np.ndarray) and otherwise.dtype.kind == "U")
):
then_str = np.array(then, dtype="U")
otherwise_str = np.array(otherwise, dtype="U")
return np.where(np.array(test, dtype="b"), then_str, otherwise_str)

return torch.where(
test_condition,
torch.tensor(test, dtype=torch.bool),
torch.tensor(then, dtype=torch.double),
torch.tensor(otherwise, dtype=torch.double),
)


@dispatcher.register_for('IFERROR')
@dispatcher.register_for("IFERROR")
def IFERROR(value, value_if_error):
return value if not isinstance(value, error.XLError) else value_if_error


@dispatcher.register_for('IFNA')
@dispatcher.register_for("IFNA")
def IFNA(value, value_if_na):
return value if value != error.NOT_AVAILABLE else value_if_na


@dispatcher.register_for('NOT')
@dispatcher.register_for("NOT")
def NOT(boolean):
return not boolean


@dispatcher.register_for('XOR')
@dispatcher.register_for("XOR")
def XOR(*args):
args = utils.iflatten(args)
result = sum(bool(a) for a in args)
return bool(result & 1)


@dispatcher.register_for('OR')
@dispatcher.register_for("OR")
def OR(*args):
args = utils.iflatten(args)
return any(args)


@dispatcher.register_for('SWITCH')
@dispatcher.register_for("SWITCH")
def SWITCH(target_value, *args):
if len(args) <= 1:
return error.NOT_AVAILABLE
Expand All @@ -83,7 +84,7 @@ def SWITCH(target_value, *args):
return error.NOT_AVAILABLE


@dispatcher.register_for('IFS')
@dispatcher.register_for("IFS")
def IFS(*args):
for pair in zip(args[::2], args[1::2]):
if pair[0]:
Expand All @@ -93,11 +94,12 @@ def IFS(*args):

# Compatibility functions

@dispatcher.register_for('TRUE')

@dispatcher.register_for("TRUE")
def TRUE():
return True


@dispatcher.register_for('FALSE')
@dispatcher.register_for("FALSE")
def FALSE():
return False
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="hotxlfp",
version="0.0.11+unc.29",
version="0.0.11+unc.30",
packages=[
"hotxlfp",
"hotxlfp._compat",
Expand Down
155 changes: 110 additions & 45 deletions tests/test_formula_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from hotxlfp import Parser
from math import pi
import numpy as np


def _test_equation(
Expand All @@ -20,10 +21,13 @@ def _test_equation(
except (error.XLError, TypeError):
assert should_fail
return
assert isinstance(result, torch.Tensor)
assert (
torch.abs(result - torch.tensor(answer)) < 0.000001
).all(), f"{result} != {answer}"
if isinstance(result, np.ndarray):
assert (result == answer).all(), f"{result} != {answer}"
else:
assert isinstance(result, torch.Tensor)
assert (
torch.abs(result - torch.tensor(answer)) < 0.000001
).all(), f"{result} != {answer}"


class TestFormulaParser(unittest.TestCase):
Expand Down Expand Up @@ -347,7 +351,7 @@ def test_decimal(self):
)

def test_order_of_operations(self):
_test_equation(equation="2^-2-1", variables={"a1": [1]}, answer=[-.75])
_test_equation(equation="2^-2-1", variables={"a1": [1]}, answer=[-0.75])
_test_equation(equation="(2^(-1))", variables={"a1": [1]}, answer=[0.5])
_test_equation(equation="((2^(-1))-1)", variables={"a1": [1]}, answer=[-0.5])
_test_equation(equation="2^2-1", variables={"a1": [1]}, answer=[3])
Expand All @@ -360,55 +364,116 @@ def test_order_of_operations(self):
_test_equation(equation="1 + 2 * 3 - 4", variables={"a1": [1]}, answer=[3])

def test_if_statement_args(self):
_test_equation(equation="IF(a1 > 10, 1, 100, 400)", variables={"a1": [4]}, should_fail=True)
_test_equation(equation="IF(a1 > 10, 400)", variables={"a1": [4]}, should_fail=True)
_test_equation(equation="IF(a1 > 10, )", variables={"a1": [4]}, should_fail=True)
_test_equation(
equation="IF(a1 > 10, 1, 100, 400)", variables={"a1": [4]}, should_fail=True
)
_test_equation(
equation="IF(a1 > 10, 400)", variables={"a1": [4]}, should_fail=True
)
_test_equation(
equation="IF(a1 > 10, )", variables={"a1": [4]}, should_fail=True
)
_test_equation(equation="IF(,, )", variables={"a1": [4]}, should_fail=True)
_test_equation(equation="IF(a1 > 100, 40, IF(a1 > 1, 4, 56))", variables={"a1": [40]}, answer=[4])
_test_equation(equation="IF(a1 > 10, 40, IF(a1 > 10, 4))", variables={"a1": [4]}, should_fail=True)
_test_equation(
equation="IF(a1 > 100, 40, IF(a1 > 1, 4, 56))",
variables={"a1": [40]},
answer=[4],
)
_test_equation(
equation="IF(a1 > 10, 40, IF(a1 > 10, 4))",
variables={"a1": [4]},
should_fail=True,
)
_test_equation(
equation="IF(a1 > 10, 'abc', 'def')", variables={"a1": [4]}, answer=["def"]
)
_test_equation(
equation="IF(a1 > 100, 'abc', IF(a1 > 1, 4, 56))",
variables={"a1": [40]},
answer=["4.0"],
)

def test_tensors(self):
_test_equation(equation="MIN(a1 * 2, 2, 23, a1)", variables={"a1": [5]}, answer=[2])
_test_equation(
equation="MIN(a1 * 2, 2, 23, a1)", variables={"a1": [5]}, answer=[2]
)
_test_equation(equation="MIN(2, a1 * 2)", variables={"a1": [5]}, answer=[2])
_test_equation(equation="MAX(a1 * 2, 2)", variables={"a1": [5]}, answer=[10])
_test_equation(equation="MAX(2, a1 * 2)", variables={"a1": [5]}, answer=[10])
_test_equation(equation="MAX(MAX(2, a1 * 2), 100)", variables={"a1": [5, 4]}, answer=[100, 100])
_test_equation(
equation="MAX(MAX(2, a1 * 2), 100)",
variables={"a1": [5, 4]},
answer=[100, 100],
)
_test_equation(equation="5", variables={"a1": [5, 4]}, answer=[5])
_test_equation(equation="SQRT(100)", variables={"a1": [5]}, answer=[10])
_test_equation(equation="CEILING(a1)", variables={"a1": [4.5, -1.2]}, answer=[5, -1])
_test_equation(equation="CEILING(a1, a2)", variables={"a1": [0.5, 0.5], "a2": [1, 2]}, answer=[1, 2])
_test_equation(equation="CEILING(a1, a2)", variables={"a1": [0.5, 0.5], "a2": [2]}, answer=[2, 2])
_test_equation(equation="CEILING(a1, a2)", variables={"a1": [0.5], "a2": [1]}, answer=[1])
_test_equation(equation="IF(a1 <> a2, 1, 0)", variables={"a1": [3, 2], "a2": [3, 3]}, answer=[0, 1])
_test_equation(equation="IF(a1 <> a2, 1, 0)", variables={"a1": 4, "a2": 2}, answer=1)
_test_equation(equation="IF(a1 < a2, 1, 0)", variables={"a1": [1, 2], "a2": [0, 0]}, answer=[0, 0])
_test_equation(equation="IF(a1 < a2, 1, 0)", variables={"a1": 2, "a2": 3}, answer=1)
_test_equation(
equation="CEILING(a1)", variables={"a1": [4.5, -1.2]}, answer=[5, -1]
)
_test_equation(
equation="CEILING(a1, a2)",
variables={"a1": [0.5, 0.5], "a2": [1, 2]},
answer=[1, 2],
)
_test_equation(
equation="CEILING(a1, a2)",
variables={"a1": [0.5, 0.5], "a2": [2]},
answer=[2, 2],
)
_test_equation(
equation="CEILING(a1, a2)", variables={"a1": [0.5], "a2": [1]}, answer=[1]
)
_test_equation(
equation="IF(a1 <> a2, 1, 0)",
variables={"a1": [3, 2], "a2": [3, 3]},
answer=[0, 1],
)
_test_equation(
equation="IF(a1 <> a2, 1, 0)", variables={"a1": 4, "a2": 2}, answer=1
)
_test_equation(
equation="IF(a1 < a2, 1, 0)",
variables={"a1": [1, 2], "a2": [0, 0]},
answer=[0, 0],
)
_test_equation(
equation="IF(a1 < a2, 1, 0)", variables={"a1": 2, "a2": 3}, answer=1
)

def test_scientific_notation(self):
_test_equation(equation="2e2", variables={"a1" : [1.1]}, answer=[200])
_test_equation(equation="5(m)", variables={"m" : [10]}, answer=[50])
_test_equation(equation="5(e)", variables={"e" : [10]}, answer=[50])
_test_equation(equation="5(evar)", variables={"evar" : [10]}, answer=[50])
_test_equation(equation="5(vare)", variables={"vare" : [10]}, answer=[50])
_test_equation(equation="1 e 2", variables={"a1" : [1.1]}, answer=[100])
_test_equation(equation="1e2", variables={"a1" : [1.1]}, answer=[100])
_test_equation(equation="2*1e2", variables={"a1" : [1.1]}, answer=[200])
_test_equation(equation="2*1e2^3", variables={"a1" : [1.1]}, answer=[2000000])
_test_equation(equation="(2*1e2)^3", variables={"a1" : [1.1]}, answer=[8000000])
_test_equation(equation="(2)e(4)", variables={"a1" : [1.1]}, answer=[8000000], should_fail=True)
_test_equation(equation="(2)e(4)", variables={"a1" : [1.1]}, should_fail=True)
_test_equation(equation="0.2e2", variables={"a1" : [1.1]}, answer=20)
_test_equation(equation="0.2e0.2", variables={"a1" : [1.1]}, answer=0.2 * (10 ** 0.2))
_test_equation(equation="2e-1", variables={"a1" : [1.1]}, answer=0.2)
_test_equation(equation="-2e1", variables={"a1" : [1.1]}, answer=-20)
_test_equation(equation="-2e-1", variables={"a1" : [1.1]}, answer=-0.2)
_test_equation(equation="-2e-.1", variables={"a1" : [1.1]}, answer=-2 * (10 ** (-.1)))


_test_equation(equation="2E2", variables={"a1" : [1.1]}, answer=[200])
_test_equation(equation="2*1E2", variables={"a1" : [1.1]}, answer=[200])
_test_equation(equation="0.2E2", variables={"a1" : [1.1]}, answer=20)
_test_equation(equation="-2E-1", variables={"a1" : [1.1]}, answer=-0.2)
_test_equation(equation="2e2", variables={"a1": [1.1]}, answer=[200])
_test_equation(equation="5(m)", variables={"m": [10]}, answer=[50])
_test_equation(equation="5(e)", variables={"e": [10]}, answer=[50])
_test_equation(equation="5(evar)", variables={"evar": [10]}, answer=[50])
_test_equation(equation="5(vare)", variables={"vare": [10]}, answer=[50])
_test_equation(equation="1 e 2", variables={"a1": [1.1]}, answer=[100])
_test_equation(equation="1e2", variables={"a1": [1.1]}, answer=[100])
_test_equation(equation="2*1e2", variables={"a1": [1.1]}, answer=[200])
_test_equation(equation="2*1e2^3", variables={"a1": [1.1]}, answer=[2000000])
_test_equation(equation="(2*1e2)^3", variables={"a1": [1.1]}, answer=[8000000])
_test_equation(
equation="(2)e(4)",
variables={"a1": [1.1]},
answer=[8000000],
should_fail=True,
)
_test_equation(equation="(2)e(4)", variables={"a1": [1.1]}, should_fail=True)
_test_equation(equation="0.2e2", variables={"a1": [1.1]}, answer=20)
_test_equation(
equation="0.2e0.2", variables={"a1": [1.1]}, answer=0.2 * (10**0.2)
)
_test_equation(equation="2e-1", variables={"a1": [1.1]}, answer=0.2)
_test_equation(equation="-2e1", variables={"a1": [1.1]}, answer=-20)
_test_equation(equation="-2e-1", variables={"a1": [1.1]}, answer=-0.2)
_test_equation(
equation="-2e-.1", variables={"a1": [1.1]}, answer=-2 * (10 ** (-0.1))
)

_test_equation(equation="2E2", variables={"a1": [1.1]}, answer=[200])
_test_equation(equation="2*1E2", variables={"a1": [1.1]}, answer=[200])
_test_equation(equation="0.2E2", variables={"a1": [1.1]}, answer=20)
_test_equation(equation="-2E-1", variables={"a1": [1.1]}, answer=-0.2)


if __name__ == "__main__":
unittest.main()
unittest.main()

0 comments on commit 4261fab

Please sign in to comment.