Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warn when doing O(n) operations on RedisSets #515

Merged
merged 2 commits into from
Dec 9, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions pottery/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import collections.abc
import itertools
import warnings
from typing import Any
from typing import Iterable
from typing import List
Expand All @@ -34,6 +35,7 @@
from .base import Base
from .base import Iterable_
from .base import JSONTypes
from .exceptions import InefficientAccessWarning
from .exceptions import KeyExistsError


Expand Down Expand Up @@ -73,6 +75,10 @@ def __contains__(self, value: Any) -> bool:
return False

def _scan(self, *, cursor: int = 0) -> Tuple[int, List[bytes]]:
warnings.warn(
cast(str, InefficientAccessWarning.__doc__),
InefficientAccessWarning,
)
return self.redis.sscan(self.key, cursor=cursor)

def __len__(self) -> int:
Expand All @@ -91,6 +97,10 @@ def discard(self, value: JSONTypes) -> None:

def __repr__(self) -> str:
'Return the string representation of the RedisSet. O(n)'
warnings.warn(
cast(str, InefficientAccessWarning.__doc__),
InefficientAccessWarning,
)
set_ = {self._decode(value) for value in self.redis.smembers(self.key)}
return self.__class__.__name__ + str(set_)

Expand All @@ -117,20 +127,32 @@ def isdisjoint(self, other: Iterable[Any]) -> bool:

# Where does this method come from?
def issubset(self, other: Iterable[Any]) -> bool:
with self._watch(other):
if not isinstance(other, collections.abc.Set):
other = frozenset(other)
return self <= other
'Report whether another set contains this set. O(n)'
return self.__sub_or_super(other, set_method='__le__')

# Where does this method come from?
def issuperset(self, other: Iterable[Any]) -> bool:
'Report whether this set contains another set. O(n)'
return self.__sub_or_super(other, set_method='__ge__')

def __sub_or_super(self,
other: Iterable[Any],
*,
set_method: Literal['__le__', '__ge__'],
) -> bool:
warnings.warn(
cast(str, InefficientAccessWarning.__doc__),
InefficientAccessWarning,
)
with self._watch(other):
if not isinstance(other, collections.abc.Set):
other = frozenset(other)
return self >= other
method = getattr(self, set_method)
return method(other)

# Where does this method come from?
def union(self, *others: Iterable[Any]) -> Set[Any]:
'Return the union of sets as a new set. O(n)'
return self.__set_op(
*others,
redis_method='sunion',
Expand All @@ -139,6 +161,7 @@ def union(self, *others: Iterable[Any]) -> Set[Any]:

# Where does this method come from?
def intersection(self, *others: Iterable[Any]) -> Set[Any]:
'Return the intersection of two sets as a new set. O(n)'
return self.__set_op(
*others,
redis_method='sinter',
Expand All @@ -152,6 +175,7 @@ def intersection(self, *others: Iterable[Any]) -> Set[Any]:

# Where does this method come from?
def difference(self, *others: Iterable[Any]) -> Set[Any]:
'Return the difference of two or more sets as a new set. O(n)'
return self.__set_op(
*others,
redis_method='sdiff',
Expand All @@ -163,6 +187,10 @@ def __set_op(self,
redis_method: Literal['sunion', 'sinter', 'sdiff'],
set_method: Literal['union', 'intersection', 'difference'],
) -> Set[Any]:
warnings.warn(
cast(str, InefficientAccessWarning.__doc__),
InefficientAccessWarning,
)
if self._same_redis(*others):
method = getattr(self.redis, redis_method)
keys = (self.key, *(cast(RedisSet, other).key for other in others))
Expand All @@ -182,6 +210,7 @@ def symmetric_difference(self, other: Iterable[Any]) -> NoReturn: # pragma: no

# Where does this method come from?
def update(self, *others: Iterable[JSONTypes]) -> None:
'Update a set with the union of itself and others. O(n)'
self.__update(
*others,
redis_method='sunionstore',
Expand All @@ -194,6 +223,7 @@ def intersection_update(self, *others: Iterable[JSONTypes]) -> NoReturn: # prag

# Where does this method come from?
def difference_update(self, *others: Iterable[JSONTypes]) -> None:
'Remove all elements of another set from this set. O(n)'
self.__update(
*others,
redis_method='sdiffstore',
Expand All @@ -207,6 +237,10 @@ def __update(self,
) -> None:
if not others:
return
warnings.warn(
cast(str, InefficientAccessWarning.__doc__),
InefficientAccessWarning,
)
if self._same_redis(*others):
method = getattr(self.redis, redis_method)
keys = (
Expand Down