Skip to content

Commit

Permalink
Merge pull request #25 from uncountableinc/brendan/require-args
Browse files Browse the repository at this point in the history
Better torch support with tensors
  • Loading branch information
leb2 authored Feb 20, 2023
2 parents b3bdf8a + 3a18a7e commit be61c0b
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 30 deletions.
38 changes: 19 additions & 19 deletions hotxlfp/formulas/mathtrig.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,95 +23,95 @@ def ABS(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return abs(number)
return torch.abs(torch.tensor(number))


@dispatcher.register_for("ACOS")
def ACOS(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.acos(number)
return torch.acos(torch.tensor(number))


@dispatcher.register_for("ACOSH")
def ACOSH(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.log(number + math.sqrt(number * number - 1))
return torch.log(torch.tensor(number) + torch.sqrt(torch.tensor(number) * torch.tensor(number) - 1))


@dispatcher.register_for("ACOT")
def ACOT(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.atan(1 / number)
return torch.atan(1 / torch.tensor(number))


@dispatcher.register_for("ACOTH")
def ACOTH(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return 0.5 * math.log((number + 1) / (number - 1))
return 0.5 * torch.log((torch.tensor(number) + 1) / torch.tensor(number) - 1)


@dispatcher.register_for("SIN")
def SIN(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.sin(number)
return torch.sin(torch.tensor(number))


@dispatcher.register_for("SINH")
def SINH(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.sinh(number)
return torch.sinh(torch.tensor(number))


@dispatcher.register_for("ASIN")
def ASIN(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.asin(number)
return torch.asin(torch.tensor(number))


@dispatcher.register_for("ASINH")
def ASINH(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.asinh(number)
return torch.asinh(torch.tensor(number))


@dispatcher.register_for("COS")
def COS(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.cos(number)
return torch.cos(torch.tensor(number))


@dispatcher.register_for("COSH")
def COSH(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.cosh(number)
return torch.cosh(torch.tensor(number))


@dispatcher.register_for("COT")
def COT(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.cos(number) / math.sin(number)
return torch.cos(torch.tensor(number)) / torch.sin(torch.tensor(number))


@dispatcher.register_for("TAN")
Expand All @@ -127,15 +127,15 @@ def TANH(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.tanh(number)
return torch.tanh(torch.tensor(number))


@dispatcher.register_for("ATAN")
def ATAN(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.atan(number)
return torch.atan(torch.tensor(number))


@dispatcher.register_for("ATAN2")
Expand All @@ -146,31 +146,31 @@ def ATAN2(x_num, y_num):
y_num = utils.parse_number(x_num)
if isinstance(y_num, error.XLError):
return y_num
return math.atan2(x_num, y_num)
return torch.atan2(torch.tensor(x_num), torch.tensor(y_num))


@dispatcher.register_for("ATANH")
def ATANH(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.atanh(number)
return torch.atanh(torch.tensor(number))


@dispatcher.register_for("SQRT")
def SQRT(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.sqrt(number)
return torch.sqrt(torch.tensor(number))


@dispatcher.register_for("EXP")
def EXP(number):
number = utils.parse_number(number)
if isinstance(number, error.XLError):
return number
return math.e**number
return torch.exp(torch.tensor(number))


@dispatcher.register_for("LN")
Expand Down Expand Up @@ -200,7 +200,7 @@ def LOG10(number):

@dispatcher.register_for("PI")
def PI():
return math.pi
return torch.tensor(math.pi)


@dispatcher.register_for("ROUND")
Expand Down
28 changes: 25 additions & 3 deletions hotxlfp/formulas/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,31 @@
import torch


def _find_first_tensor(args) -> torch.Tensor | None:
for item in args:
if isinstance(item, torch.Tensor):
return item
return None

# Broadcasts args so that if there is a numeric arg, and a tensor arg, that the numeric
# arg will become a tensor of the same size as the numeric arg.
def broadcast_args(args):
first_tensor = _find_first_tensor(args)
if first_tensor is None:
return args

new_args = []
for arg in args:
if isinstance(arg, torch.Tensor):
new_args.append(arg)
else:
new_args.append(torch.ones_like(first_tensor) * arg)
return new_args


@dispatcher.register_for('AVERAGE')
def AVERAGE(*args):
    """Excel AVERAGE: elementwise mean over args.

    broadcast_args expands scalar args to match any tensor arg so the stack
    is rectangular; the result is a double tensor.
    """
    stacked = torch.stack(broadcast_args(args), dim=0)
    # .to(double) replaces torch.tensor(<tensor>), which re-copies the data
    # and triggers a UserWarning on tensor input.
    return torch.mean(stacked.to(torch.double), dim=0)


@dispatcher.register_for('AVEDEV')
Expand Down Expand Up @@ -71,7 +93,7 @@ def COUNTIF(args, criteria):

@dispatcher.register_for('MAX')
def MAX(*args):
    """Excel MAX: elementwise maximum over args (scalars broadcast to tensors)."""
    # as_tensor converts dtype without the copy/warning torch.tensor() emits
    # when handed a tensor from broadcast_args.
    tensors = [torch.as_tensor(val, dtype=torch.double) for val in broadcast_args(args)]
    # The stack is already a tensor; no extra torch.tensor() wrap needed.
    return torch.max(torch.stack(tensors, dim=0), dim=0).values


Expand All @@ -87,7 +109,7 @@ def MEDIAN(*args):

@dispatcher.register_for('MIN')
def MIN(*args):
    """Excel MIN: elementwise minimum over args (scalars broadcast to tensors)."""
    # as_tensor converts dtype without the copy/warning torch.tensor() emits
    # when handed a tensor from broadcast_args.
    tensors = [torch.as_tensor(val, dtype=torch.double) for val in broadcast_args(args)]
    # The stack is already a tensor; no extra torch.tensor() wrap needed.
    return torch.min(torch.stack(tensors, dim=0), dim=0).values


Expand Down
12 changes: 6 additions & 6 deletions hotxlfp/grammarparser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,18 @@ def p_expression_number(self, p):
"""
if p[1] == '.' and len(p) == 3:
p2 = p[2]
p[0] = lambda args: to_number('.' + p2)
p[0] = lambda args: to_number('.' + p2, args)
elif len(p) == 2:
p[0] = lambda args, p1=p[1]: to_number(p1)
p[0] = lambda args, p1=p[1]: to_number(p1, args)
elif p[2] == '.':
if len(p) == 4:
p[0] = lambda args, p1=p[1], p3=p[3]: to_number(p1 + '.' + p3)
p[0] = lambda args, p1=p[1], p3=p[3]: to_number(p1 + '.' + p3, args)
else:
p[0] = lambda args, p1=p[1]: to_number(p1)
p[0] = lambda args, p1=p[1]: to_number(p1, args)
elif p[2] == '^':
p[0] = lambda args, p1=p[1], p3=p[3]: to_number(p1)**to_number(p3)
p[0] = lambda args, p1=p[1], p3=p[3]: to_number(p1, args)**to_number(p3, args)
elif p[2] == '%':
p[0] = lambda args, p1=p[1]: to_number(p1) * 0.01
p[0] = lambda args, p1=p[1]: to_number(p1, args) * 0.01

def p_expression_string(self, p):
"""
Expand Down
13 changes: 12 additions & 1 deletion hotxlfp/helper/number.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-
import torch

from .._compat import number_types, string_types

def to_number(number):
def to_number_wrapper(number):
if isinstance(number, number_types):
return number
if isinstance(number, string_types):
Expand All @@ -16,5 +18,14 @@ def to_number(number):
return 1 if number else 0
return number

def to_number(number, args = None):
    """Coerce *number* to a numeric value via to_number_wrapper.

    If *args* is given (presumably a mapping of formula variables — verify
    against the parser caller) and its first value is a tensor while the
    coerced result is not, the scalar is broadcast to that tensor's shape.
    """
    value = to_number_wrapper(number)
    if args is None:
        return value
    candidates = list(args.values())
    if candidates and isinstance(candidates[0], torch.Tensor) and not isinstance(value, torch.Tensor):
        return torch.ones_like(candidates[0]) * value
    return value


def invert_number(number):
    """Return the arithmetic negation of *number* after coercing it with to_number."""
    return -1 * to_number(number)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="hotxlfp",
version="0.0.11-unc21",
version="0.0.11-unc22",
packages=[
"hotxlfp",
"hotxlfp._compat",
Expand Down
10 changes: 10 additions & 0 deletions tests/test_formula_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def _test_equation(
except (error.XLError, TypeError):
assert should_fail
return
assert isinstance(result, torch.Tensor)
assert (
torch.abs(result - torch.tensor(answer)) < 0.000001
).all(), f"{result} != {answer}"
Expand Down Expand Up @@ -366,6 +367,15 @@ def test_if_statement_args(self):
_test_equation(equation="IF(a1 > 100, 40, IF(a1 > 1, 4, 56))", variables={"a1": [40]}, answer=[4])
_test_equation(equation="IF(a1 > 10, 40, IF(a1 > 10, 4))", variables={"a1": [4]}, should_fail=True)

def test_tensors(self):
    """Scalar/tensor broadcasting through MIN, MAX, SQRT, and bare literals."""
    _test_equation(equation="MIN(a1 * 2, 2, 23, a1)", variables={"a1": [5]}, answer=[2])
    _test_equation(equation="MIN(2, a1 * 2)", variables={"a1": [5]}, answer=[2])
    _test_equation(equation="MAX(a1 * 2, 2)", variables={"a1": [5]}, answer=[10])
    _test_equation(equation="MAX(2, a1 * 2)", variables={"a1": [5]}, answer=[10])
    # Nested call: inner MAX yields a length-2 tensor, outer broadcasts 100.
    _test_equation(equation="MAX(MAX(2, a1 * 2), 100)", variables={"a1": [5, 4]}, answer=[100, 100])
    # A bare literal is broadcast against the tensor variables in scope.
    _test_equation(equation="5", variables={"a1": [5, 4]}, answer=[5])
    _test_equation(equation="SQRT(100)", variables={"a1": [5]}, answer=[10])


if __name__ == "__main__":
unittest.main()

0 comments on commit be61c0b

Please sign in to comment.