Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add string support to if statement #33

Merged
merged 2 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 25 additions & 23 deletions hotxlfp/formulas/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,69 +7,70 @@
from . import error
from . import utils
import torch
import numpy as np


@dispatcher.register_for('AND')
@dispatcher.register_for("AND")
def AND(*args):
args = utils.iflatten(args)
return all(args)


@dispatcher.register_for('IF')
@dispatcher.register_for("IF")
def IF(test, then, otherwise):
if isinstance(test, error.XLError):
return error.XLError
test_condition = torch.tensor(test, dtype=torch.bool)
if test_condition.all():
if isinstance(then, error.XLError):
return then
return torch.tensor(then, dtype=torch.double)
elif (~test_condition).all():
if isinstance(otherwise, error.XLError):
return otherwise
return torch.tensor(otherwise, dtype=torch.double)

if isinstance(then, error.XLError):
return then
if isinstance(otherwise, error.XLError):
return otherwise


if (
isinstance(then, str)
or isinstance(otherwise, str)
or (isinstance(then, np.ndarray) and then.dtype.kind == "U")
or (isinstance(otherwise, np.ndarray) and otherwise.dtype.kind == "U")
):
then_str = np.array(then, dtype="U")
otherwise_str = np.array(otherwise, dtype="U")
return np.where(np.array(test, dtype="b"), then_str, otherwise_str)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a good strategy. I'll test to see if this np.array strategy works on our platform, but it also looks like it shouldn't cause any regressions because of the type checks you are doing.


return torch.where(
test_condition,
torch.tensor(test, dtype=torch.bool),
torch.tensor(then, dtype=torch.double),
torch.tensor(otherwise, dtype=torch.double),
)


@dispatcher.register_for('IFERROR')
@dispatcher.register_for("IFERROR")
def IFERROR(value, value_if_error):
return value if not isinstance(value, error.XLError) else value_if_error


@dispatcher.register_for('IFNA')
@dispatcher.register_for("IFNA")
def IFNA(value, value_if_na):
return value if value != error.NOT_AVAILABLE else value_if_na


@dispatcher.register_for('NOT')
@dispatcher.register_for("NOT")
def NOT(boolean):
return not boolean


@dispatcher.register_for('XOR')
@dispatcher.register_for("XOR")
def XOR(*args):
args = utils.iflatten(args)
result = sum(bool(a) for a in args)
return bool(result & 1)


@dispatcher.register_for('OR')
@dispatcher.register_for("OR")
def OR(*args):
args = utils.iflatten(args)
return any(args)


@dispatcher.register_for('SWITCH')
@dispatcher.register_for("SWITCH")
def SWITCH(target_value, *args):
if len(args) <= 1:
return error.NOT_AVAILABLE
Expand All @@ -83,7 +84,7 @@ def SWITCH(target_value, *args):
return error.NOT_AVAILABLE


@dispatcher.register_for('IFS')
@dispatcher.register_for("IFS")
def IFS(*args):
for pair in zip(args[::2], args[1::2]):
if pair[0]:
Expand All @@ -93,11 +94,12 @@ def IFS(*args):

# Compatibility functions

@dispatcher.register_for('TRUE')

@dispatcher.register_for("TRUE")
def TRUE():
return True


@dispatcher.register_for('FALSE')
@dispatcher.register_for("FALSE")
def FALSE():
return False
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="hotxlfp",
version="0.0.11+unc.29",
version="0.0.11+unc.30",
packages=[
"hotxlfp",
"hotxlfp._compat",
Expand Down
155 changes: 110 additions & 45 deletions tests/test_formula_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from hotxlfp import Parser
from math import pi
import numpy as np


def _test_equation(
Expand All @@ -20,10 +21,13 @@ def _test_equation(
except (error.XLError, TypeError):
assert should_fail
return
assert isinstance(result, torch.Tensor)
assert (
torch.abs(result - torch.tensor(answer)) < 0.000001
).all(), f"{result} != {answer}"
if isinstance(result, np.ndarray):
assert (result == answer).all(), f"{result} != {answer}"
else:
assert isinstance(result, torch.Tensor)
assert (
torch.abs(result - torch.tensor(answer)) < 0.000001
).all(), f"{result} != {answer}"


class TestFormulaParser(unittest.TestCase):
Expand Down Expand Up @@ -347,7 +351,7 @@ def test_decimal(self):
)

def test_order_of_operations(self):
_test_equation(equation="2^-2-1", variables={"a1": [1]}, answer=[-.75])
_test_equation(equation="2^-2-1", variables={"a1": [1]}, answer=[-0.75])
_test_equation(equation="(2^(-1))", variables={"a1": [1]}, answer=[0.5])
_test_equation(equation="((2^(-1))-1)", variables={"a1": [1]}, answer=[-0.5])
_test_equation(equation="2^2-1", variables={"a1": [1]}, answer=[3])
Expand All @@ -360,55 +364,116 @@ def test_order_of_operations(self):
_test_equation(equation="1 + 2 * 3 - 4", variables={"a1": [1]}, answer=[3])

def test_if_statement_args(self):
_test_equation(equation="IF(a1 > 10, 1, 100, 400)", variables={"a1": [4]}, should_fail=True)
_test_equation(equation="IF(a1 > 10, 400)", variables={"a1": [4]}, should_fail=True)
_test_equation(equation="IF(a1 > 10, )", variables={"a1": [4]}, should_fail=True)
_test_equation(
equation="IF(a1 > 10, 1, 100, 400)", variables={"a1": [4]}, should_fail=True
)
_test_equation(
equation="IF(a1 > 10, 400)", variables={"a1": [4]}, should_fail=True
)
_test_equation(
equation="IF(a1 > 10, )", variables={"a1": [4]}, should_fail=True
)
_test_equation(equation="IF(,, )", variables={"a1": [4]}, should_fail=True)
_test_equation(equation="IF(a1 > 100, 40, IF(a1 > 1, 4, 56))", variables={"a1": [40]}, answer=[4])
_test_equation(equation="IF(a1 > 10, 40, IF(a1 > 10, 4))", variables={"a1": [4]}, should_fail=True)
_test_equation(
equation="IF(a1 > 100, 40, IF(a1 > 1, 4, 56))",
variables={"a1": [40]},
answer=[4],
)
_test_equation(
equation="IF(a1 > 10, 40, IF(a1 > 10, 4))",
variables={"a1": [4]},
should_fail=True,
)
_test_equation(
equation="IF(a1 > 10, 'abc', 'def')", variables={"a1": [4]}, answer=["def"]
)
_test_equation(
equation="IF(a1 > 100, 'abc', IF(a1 > 1, 4, 56))",
variables={"a1": [40]},
answer=["4.0"],
)

def test_tensors(self):
_test_equation(equation="MIN(a1 * 2, 2, 23, a1)", variables={"a1": [5]}, answer=[2])
_test_equation(
equation="MIN(a1 * 2, 2, 23, a1)", variables={"a1": [5]}, answer=[2]
)
_test_equation(equation="MIN(2, a1 * 2)", variables={"a1": [5]}, answer=[2])
_test_equation(equation="MAX(a1 * 2, 2)", variables={"a1": [5]}, answer=[10])
_test_equation(equation="MAX(2, a1 * 2)", variables={"a1": [5]}, answer=[10])
_test_equation(equation="MAX(MAX(2, a1 * 2), 100)", variables={"a1": [5, 4]}, answer=[100, 100])
_test_equation(
equation="MAX(MAX(2, a1 * 2), 100)",
variables={"a1": [5, 4]},
answer=[100, 100],
)
_test_equation(equation="5", variables={"a1": [5, 4]}, answer=[5])
_test_equation(equation="SQRT(100)", variables={"a1": [5]}, answer=[10])
_test_equation(equation="CEILING(a1)", variables={"a1": [4.5, -1.2]}, answer=[5, -1])
_test_equation(equation="CEILING(a1, a2)", variables={"a1": [0.5, 0.5], "a2": [1, 2]}, answer=[1, 2])
_test_equation(equation="CEILING(a1, a2)", variables={"a1": [0.5, 0.5], "a2": [2]}, answer=[2, 2])
_test_equation(equation="CEILING(a1, a2)", variables={"a1": [0.5], "a2": [1]}, answer=[1])
_test_equation(equation="IF(a1 <> a2, 1, 0)", variables={"a1": [3, 2], "a2": [3, 3]}, answer=[0, 1])
_test_equation(equation="IF(a1 <> a2, 1, 0)", variables={"a1": 4, "a2": 2}, answer=1)
_test_equation(equation="IF(a1 < a2, 1, 0)", variables={"a1": [1, 2], "a2": [0, 0]}, answer=[0, 0])
_test_equation(equation="IF(a1 < a2, 1, 0)", variables={"a1": 2, "a2": 3}, answer=1)
_test_equation(
equation="CEILING(a1)", variables={"a1": [4.5, -1.2]}, answer=[5, -1]
)
_test_equation(
equation="CEILING(a1, a2)",
variables={"a1": [0.5, 0.5], "a2": [1, 2]},
answer=[1, 2],
)
_test_equation(
equation="CEILING(a1, a2)",
variables={"a1": [0.5, 0.5], "a2": [2]},
answer=[2, 2],
)
_test_equation(
equation="CEILING(a1, a2)", variables={"a1": [0.5], "a2": [1]}, answer=[1]
)
_test_equation(
equation="IF(a1 <> a2, 1, 0)",
variables={"a1": [3, 2], "a2": [3, 3]},
answer=[0, 1],
)
_test_equation(
equation="IF(a1 <> a2, 1, 0)", variables={"a1": 4, "a2": 2}, answer=1
)
_test_equation(
equation="IF(a1 < a2, 1, 0)",
variables={"a1": [1, 2], "a2": [0, 0]},
answer=[0, 0],
)
_test_equation(
equation="IF(a1 < a2, 1, 0)", variables={"a1": 2, "a2": 3}, answer=1
)

def test_scientific_notation(self):
_test_equation(equation="2e2", variables={"a1" : [1.1]}, answer=[200])
_test_equation(equation="5(m)", variables={"m" : [10]}, answer=[50])
_test_equation(equation="5(e)", variables={"e" : [10]}, answer=[50])
_test_equation(equation="5(evar)", variables={"evar" : [10]}, answer=[50])
_test_equation(equation="5(vare)", variables={"vare" : [10]}, answer=[50])
_test_equation(equation="1 e 2", variables={"a1" : [1.1]}, answer=[100])
_test_equation(equation="1e2", variables={"a1" : [1.1]}, answer=[100])
_test_equation(equation="2*1e2", variables={"a1" : [1.1]}, answer=[200])
_test_equation(equation="2*1e2^3", variables={"a1" : [1.1]}, answer=[2000000])
_test_equation(equation="(2*1e2)^3", variables={"a1" : [1.1]}, answer=[8000000])
_test_equation(equation="(2)e(4)", variables={"a1" : [1.1]}, answer=[8000000], should_fail=True)
_test_equation(equation="(2)e(4)", variables={"a1" : [1.1]}, should_fail=True)
_test_equation(equation="0.2e2", variables={"a1" : [1.1]}, answer=20)
_test_equation(equation="0.2e0.2", variables={"a1" : [1.1]}, answer=0.2 * (10 ** 0.2))
_test_equation(equation="2e-1", variables={"a1" : [1.1]}, answer=0.2)
_test_equation(equation="-2e1", variables={"a1" : [1.1]}, answer=-20)
_test_equation(equation="-2e-1", variables={"a1" : [1.1]}, answer=-0.2)
_test_equation(equation="-2e-.1", variables={"a1" : [1.1]}, answer=-2 * (10 ** (-.1)))


_test_equation(equation="2E2", variables={"a1" : [1.1]}, answer=[200])
_test_equation(equation="2*1E2", variables={"a1" : [1.1]}, answer=[200])
_test_equation(equation="0.2E2", variables={"a1" : [1.1]}, answer=20)
_test_equation(equation="-2E-1", variables={"a1" : [1.1]}, answer=-0.2)
_test_equation(equation="2e2", variables={"a1": [1.1]}, answer=[200])
_test_equation(equation="5(m)", variables={"m": [10]}, answer=[50])
_test_equation(equation="5(e)", variables={"e": [10]}, answer=[50])
_test_equation(equation="5(evar)", variables={"evar": [10]}, answer=[50])
_test_equation(equation="5(vare)", variables={"vare": [10]}, answer=[50])
_test_equation(equation="1 e 2", variables={"a1": [1.1]}, answer=[100])
_test_equation(equation="1e2", variables={"a1": [1.1]}, answer=[100])
_test_equation(equation="2*1e2", variables={"a1": [1.1]}, answer=[200])
_test_equation(equation="2*1e2^3", variables={"a1": [1.1]}, answer=[2000000])
_test_equation(equation="(2*1e2)^3", variables={"a1": [1.1]}, answer=[8000000])
_test_equation(
equation="(2)e(4)",
variables={"a1": [1.1]},
answer=[8000000],
should_fail=True,
)
_test_equation(equation="(2)e(4)", variables={"a1": [1.1]}, should_fail=True)
_test_equation(equation="0.2e2", variables={"a1": [1.1]}, answer=20)
_test_equation(
equation="0.2e0.2", variables={"a1": [1.1]}, answer=0.2 * (10**0.2)
)
_test_equation(equation="2e-1", variables={"a1": [1.1]}, answer=0.2)
_test_equation(equation="-2e1", variables={"a1": [1.1]}, answer=-20)
_test_equation(equation="-2e-1", variables={"a1": [1.1]}, answer=-0.2)
_test_equation(
equation="-2e-.1", variables={"a1": [1.1]}, answer=-2 * (10 ** (-0.1))
)

_test_equation(equation="2E2", variables={"a1": [1.1]}, answer=[200])
_test_equation(equation="2*1E2", variables={"a1": [1.1]}, answer=[200])
_test_equation(equation="0.2E2", variables={"a1": [1.1]}, answer=20)
_test_equation(equation="-2E-1", variables={"a1": [1.1]}, answer=-0.2)


if __name__ == "__main__":
unittest.main()
unittest.main()