Skip to content

Commit

Permalink
feat: add BigO analysis (#90)
Browse files Browse the repository at this point in the history
* feat: add BigO analysis

* docs: add analysis to api reference

* fix: add warnings for BigO

* fix: improve warnings for BigO
  • Loading branch information
mstechly authored Jul 23, 2024
1 parent a05acc8 commit 7a24848
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

::: bartiq

::: bartiq.analysis

::: bartiq.precompilation

::: bartiq.symbolics
Expand Down
119 changes: 119 additions & 0 deletions src/bartiq/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2024 PsiQuantum, Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Optional

from sympy import Expr, Function, Poly, Symbol, prod


class BigO:
def __init__(self, expr: Expr, variable: Optional[Symbol] = None):
"""Class for representing expressions in Big O notation.
It analyzes given expression and returns all the Big O terms in it.
If variable is provided, it analyses scaling in this particular variable,
otherwise it assumes all the symbols are variables.
Note:
It's an experimental tool and is meant to facilitate the analysis, but
it might not produce correct results, especially for more complicated
expressions. In case of any problems please create an issue on project's GitHub,
we'd love to hear your feedback on this!
Args:
expr: sympy expression we want to analyze
variable: variable for which we want to performa analysis.
"""
if variable is None:
gens = []
else:
gens = [variable]
self.expr = _convert_to_big_O(expr, gens)

def __add__(self, other):
if isinstance(other, self.__class__):
return BigO(_remove_big_O_function(self.expr) + _remove_big_O_function(other.expr))
else:
return BigO(_remove_big_O_function(self.expr) + _remove_big_O_function(other))

def __eq__(self, other):
return self.expr == other.expr

def __mul__(self, other):
if isinstance(other, self.__class__):
return BigO(_remove_big_O_function(self.expr) * _remove_big_O_function(other.expr))
else:
return BigO(_remove_big_O_function(self.expr) * _remove_big_O_function(other))

def __repr__(self) -> str:
return f"{self.expr}"


def _remove_big_O_function(expr: Expr) -> Expr:
args = expr.args
new_args = []
for arg in args:
if isinstance(arg, Function("O")):
assert len(arg.args) == 1
new_args.append(arg.args[0])
else:
new_args.append(arg)
return sum(new_args)


def _add_big_o_function(expr: Expr) -> Expr:
if isinstance(expr, Function("O")):
return expr
return Function("O")(expr)


def _convert_to_big_O(expr: Expr, gens: Optional[list[Expr]] = None) -> Expr:
gens = gens or []
if len(expr.free_symbols) == 0:
return _add_big_o_function(1)
if len(expr.free_symbols) > 1 and len(gens) == 0:
warnings.warn(
"Results for using BigO with multiple variables might be unreliable. "
"For better results please select a variable of interest."
)
poly = Poly(expr, *gens)
leading_terms = _get_leading_terms(poly)
return sum(map(_add_big_o_function, leading_terms))


def _get_leading_terms(poly):
terms, _ = zip(*poly.terms())
leading_terms = []
for term in terms:
if not _term_less_than_or_equal_to_all_others(term, leading_terms):
leading_terms.append(term)

return [_make_term_expression(poly.gens, leading_term) for leading_term in leading_terms]


def _term_less_than_or_equal_to_all_others(candidate, other_terms):
if not other_terms:
return False

return all(_less_than(candidate, term) for term in other_terms)


def _less_than(term_1, term_2):
return all(a <= b for a, b in zip(term_1, term_2))


def _make_term_expression(gens, term):
powers = [gen**order for gen, order in zip(gens, term)]
return prod(powers)
92 changes: 92 additions & 0 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2024 PsiQuantum, Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import sympy
from sympy.abc import x, y

from bartiq.analysis import BigO


@pytest.mark.parametrize(
"expr,variable,expected",
[
(
x**y + y**x + x * y + x**2 * y + 3 + x * y**2 + x + y + 1,
None,
BigO(x**y) + BigO(y**x) + BigO(x * y**2) + BigO(x**2 * y),
),
(
x * y + x**2 * y + 3 + x * y**3 + x + y + 1,
x,
BigO(x**2),
),
(
x * y + x**2 * y + 3 + x * y**3 + x + y + 1,
y,
BigO(y**3),
),
(
sympy.sympify("f(x) + y"),
y,
BigO(y),
),
(
sympy.sympify("log(x) + y**2 + y"),
None,
BigO(sympy.sympify("log(x)")) + BigO(y**2),
),
],
)
def test_BigO(expr, variable, expected):
assert BigO(expr, variable) == expected


def test_BigO_throws_warning_for_multiple_variables():
with pytest.warns(
match="Results for using BigO with multiple variables might be unreliable. "
"For better results please select a variable of interest."
):
BigO(x**y + y**x + x * y + x**2 * y + 3 + x * y**2 + x + y + 1)


def test_adding_BigO_expressions():
assert BigO(x) + BigO(x) == BigO(x)
assert BigO(x) * BigO(x) == BigO(x**2)
assert BigO(2 * y + 17) + BigO(x - 1) == BigO(x) + BigO(y)
assert BigO(x, variable=y) + BigO(y, variable=x) == BigO(sympy.sympify(1))


@pytest.mark.parametrize(
"expr,gens,expected",
[
(
x**0.5,
x,
BigO(x**0.5),
),
(
sympy.log(x) + sympy.log(x) * x,
x,
BigO(sympy.log(x) * x),
),
(
sympy.log(x) + sympy.log(x) * x,
x,
BigO(sympy.log(x) * x),
),
],
)
def test_failing_big_O_cases(expr, gens, expected):
pytest.xfail()

0 comments on commit 7a24848

Please sign in to comment.