Skip to content

Commit

Permalink
Merge pull request #30 from uncountableinc/catherine/ceiling-handle-t…
Browse files Browse the repository at this point in the history
…ensor-input

handle tensor inputs to ceiling function
  • Loading branch information
catherinelasersohn authored Oct 25, 2023
2 parents 0abd9c3 + 9e86881 commit 4828d52
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
37 changes: 25 additions & 12 deletions hotxlfp/formulas/mathtrig.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,23 +245,36 @@ def SUMIF(args, criteria):

@dispatcher.register_for("CEILING", "CEILING.MATH", "CEILING.PRECISE")
def CEILING(number, significance=1):
number = utils.parse_number(number)
number = torch.tensor(utils.parse_number(number))
significance = utils.parse_number(significance)
if not isinstance(significance, torch.Tensor) or significance.size(dim=0) == 1:
significance = torch.broadcast_to(torch.tensor(significance), number.size())

if utils.any_is_error((number, significance)):
return error.VALUE
if significance == 0:
return 0
if number.size(dim=0) != significance.size(dim=0):
return error.VALUE

positive_significance = significance > 0
significance = abs(significance)
if number >= 0:
return math.ceil(number / significance) * significance
else:
if positive_significance:
return -1 * math.floor(abs(number) / significance) * significance
else:
return -1 * math.ceil(abs(number) / significance) * significance
positive_number = torch.where(
(number >= 0) & (significance != 0),
torch.ceil(number / torch.abs(significance)) * torch.abs(significance),
0,
)

positive_significance = torch.where(
(number < 0) & (significance > 0),
-1 * torch.floor(torch.abs(number) / significance) * significance,
0,
)

negative_significance = torch.where(
(number < 0) & (significance < 0),
-1 * torch.ceil(torch.abs(number) / torch.abs(significance)) * torch.abs(significance),
0,
)

results = positive_number + positive_significance + negative_significance
return results


@dispatcher.register_for("FLOOR", "FLOOR.MATH", "FLOOR.PRECISE")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="hotxlfp",
version="0.0.11+unc.26",
version="0.0.11+unc.27",
packages=[
"hotxlfp",
"hotxlfp._compat",
Expand Down
4 changes: 4 additions & 0 deletions tests/test_formula_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,10 @@ def test_tensors(self):
_test_equation(equation="MAX(MAX(2, a1 * 2), 100)", variables={"a1": [5, 4]}, answer=[100, 100])
_test_equation(equation="5", variables={"a1": [5, 4]}, answer=[5])
_test_equation(equation="SQRT(100)", variables={"a1": [5]}, answer=[10])
_test_equation(equation="CEILING(a1)", variables={"a1": [4.5, -1.2]}, answer=[5, -1])
_test_equation(equation="CEILING(a1, a2)", variables={"a1": [0.5, 0.5], "a2": [1, 2]}, answer=[1, 2])
_test_equation(equation="CEILING(a1, a2)", variables={"a1": [0.5, 0.5], "a2": [2]}, answer=[2, 2])
_test_equation(equation="CEILING(a1, a2)", variables={"a1": [0.5], "a2": [1]}, answer=[1])

def test_scientific_notation(self):
_test_equation(equation="2e2", variables={"a1" : [1.1]}, answer=[200])
Expand Down

0 comments on commit 4828d52

Please sign in to comment.