Skip to content

Commit

Permalink
Remove a couple type ignores + cleanup CacheMappingView (#657)
Browse files Browse the repository at this point in the history
* Work around a MyPy inference (? or typeshed) bug

* Clean up CacheMappingView

* A few more type ignores

* Please flake8

* Revert isinstance checks

* Use TypeGuards instead

Other two type ignores here are pending update of typeshed in mypy:
python/typeshed#5473

* Code review feedback

Co-authored-by: davfsa <[email protected]>
Co-authored-by: FasterSpeeding <[email protected]>

* Use mypy attrs plugin more effectively

* Fix merge conflicts

* Remove redundancy in config

* Re-order type overloads and abstract method decorators

Now, this PR follows the convention found elsewhere in the library.
(In hikari/api/rest.py)

* Actually run converters, as well as have a docstring for ssl

* davfsa patch to cache, with a couple changes

* Don't run converters on setattr

* Fix Cache3DMappingView

(why isn't mypy warning me about this!!)

* Remove transparent mapping

* Remove redundant else clause

* Reformat

Co-authored-by: davfsa <[email protected]>
Co-authored-by: FasterSpeeding <[email protected]>
  • Loading branch information
3 people authored Aug 6, 2021
1 parent bf99770 commit 88dce72
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 87 deletions.
14 changes: 12 additions & 2 deletions hikari/api/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,19 @@ class CacheView(typing.Mapping[_KeyT, _ValueT], abc.ABC):

__slots__: typing.Sequence[str] = ()

@typing.overload
@abc.abstractmethod
def get_item_at(self, index: int) -> _ValueT:
"""Get an entry in the view at position `index`."""
def get_item_at(self, index: int, /) -> _ValueT:
...

@typing.overload
@abc.abstractmethod
def get_item_at(self, index: slice, /) -> typing.Sequence[_ValueT]:
...

@abc.abstractmethod
def get_item_at(self, index: typing.Union[slice, int], /) -> typing.Union[_ValueT, typing.Sequence[_ValueT]]:
...

@abc.abstractmethod
def iterator(self) -> iterators.LazyIterator[_ValueT]:
Expand Down
70 changes: 31 additions & 39 deletions hikari/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,55 +330,47 @@ def _(self, _: attr.Attribute[typing.Optional[int]], value: typing.Optional[int]
if value is not None and (not isinstance(value, int) or value <= 0):
raise ValueError("http_settings.max_redirects must be None or a POSITIVE integer")

_ssl: typing.Union[bool, ssl_.SSLContext] = attr.field(
default=True,
ssl: ssl_.SSLContext = attr.field(
factory=lambda: _ssl_factory(True),
converter=_ssl_factory,
validator=attr.validators.instance_of(ssl_.SSLContext), # type: ignore[assignment,arg-type]
validator=attr.validators.instance_of(ssl_.SSLContext),
)
"""SSL context to use.
@property
def ssl(self) -> ssl_.SSLContext:
"""SSL context to use.
This may be __assigned__ a `builtins.bool` or an `ssl.SSLContext` object.
This may be __assigned__ a `builtins.bool` or an `ssl.SSLContext` object.
If assigned to `builtins.True`, a default SSL context is generated by
this class that will enforce SSL verification. This is then stored in
this field.
If assigned to `builtins.True`, a default SSL context is generated by
this class that will enforce SSL verification. This is then stored in
this field.
If `builtins.False`, then a default SSL context is generated by this
class that will **NOT** enforce SSL verification. This is then stored
in this field.
If `builtins.False`, then a default SSL context is generated by this
class that will **NOT** enforce SSL verification. This is then stored
in this field.
If an instance of `ssl.SSLContext`, then this context will be used.
If an instance of `ssl.SSLContext`, then this context will be used.
!!! warning
Setting a custom value here may have security implications, or
may result in the application being unable to connect to Discord
at all.
!!! warning
Setting a custom value here may have security implications, or
may result in the application being unable to connect to Discord
at all.
!!! warning
Disabling SSL verification is almost always unadvised. This
is because your application will no longer check whether you are
connecting to Discord, or to some third party spoof designed
to steal personal credentials such as your application token.
!!! warning
Disabling SSL verification is almost always unadvised. This
is because your application will no longer check whether you are
connecting to Discord, or to some third party spoof designed
to steal personal credentials such as your application token.
There may be cases where SSL certificates do not get updated,
and in this case, you may find that disabling this explicitly
allows you to work around any issues that are occurring, but
you should immediately seek a better solution where possible
if any form of personal security is in your interest.
There may be cases where SSL certificates do not get updated,
and in this case, you may find that disabling this explicitly
allows you to work around any issues that are occurring, but
you should immediately seek a better solution where possible
if any form of personal security is in your interest.
Returns
-------
ssl.SSLContext
The SSL context to use for this application.
"""
ssl = self._ssl
assert isinstance(
ssl, ssl_.SSLContext
), f"expected ssl.SSLContext, found {type(ssl)!r}. Did you overwrite the value?"
return ssl
Returns
-------
ssl.SSLContext
The SSL context to use for this application.
"""

timeouts: HTTPTimeoutSettings = attr.field(
factory=HTTPTimeoutSettings, validator=attr.validators.instance_of(HTTPTimeoutSettings)
Expand Down
2 changes: 1 addition & 1 deletion hikari/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ async def _wrap_iter(self) -> typing.AsyncGenerator[typing.Any, bytes]:
pass

elif aio.is_async_iterable(self.data):
async for chunk in self.data: # type: ignore[union-attr]
async for chunk in self.data:
yield self._assert_bytes(chunk)

elif isinstance(self.data, typing.Iterable):
Expand Down
8 changes: 6 additions & 2 deletions hikari/internal/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
import inspect
import typing

if typing.TYPE_CHECKING:
# typing_extensions is a dependency of mypy, and pyright vendors it.
from typing_extensions import TypeGuard

T_co = typing.TypeVar("T_co", covariant=True)
T_inv = typing.TypeVar("T_inv")

Expand Down Expand Up @@ -84,12 +88,12 @@ def completed_future(result: typing.Optional[T_inv] = None, /) -> asyncio.Future
# ... so I guess I will have to determine this some other way.


def is_async_iterator(obj: typing.Any) -> bool:
def is_async_iterator(obj: typing.Any) -> TypeGuard[typing.AsyncIterator[object]]:
"""Determine if the object is an async iterator or not."""
return asyncio.iscoroutinefunction(getattr(obj, "__anext__", None))


def is_async_iterable(obj: typing.Any) -> bool:
def is_async_iterable(obj: typing.Any) -> TypeGuard[typing.AsyncIterable[object]]:
"""Determine if the object is an async iterable or not."""
attr = getattr(obj, "__aiter__", None)
return inspect.isfunction(attr) or inspect.ismethod(attr)
Expand Down
83 changes: 40 additions & 43 deletions hikari/internal/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,80 +83,77 @@
"""Type-hint for mapping values."""


class CacheMappingView(cache.CacheView[KeyT, ValueT], typing.Generic[KeyT, ValueT]):
class CacheMappingView(cache.CacheView[KeyT, ValueT]):
"""A cache mapping view implementation used for representing cached data.
Parameters
----------
items : typing.Mapping[KeyT, typing.Union[ValueT, DataT, RefCell[ValueT]]]
items : typing.Union[typing.Mapping[KeyT, ValueT], typing.Mapping[KeyT, DataT]]
A mapping of keys to the values in their raw forms, wrapped by a ref
wrapper or in a data form.
builder : typing.Optional[typing.Callable[[DataT], ValueT]]
The callable used to build entities before they're returned by the
mapping. This is used to cover the case when items stores `DataT` objects.
predicate : typing.Optional[typing.Callable[[typing.Any], bool]]
A callable to use to determine whether entries should be returned or hidden,
this should take in whatever raw type was passed for the value in `items`.
This may be `builtins.None` if all entries should be exposed.
"""

__slots__: typing.Sequence[str] = ("_builder", "_data", "_predicate")
__slots__: typing.Sequence[str] = ("_data", "_builder")

@typing.overload
def __init__(
self,
items: typing.Mapping[KeyT, typing.Union[ValueT, DataT]],
items: typing.Mapping[KeyT, ValueT],
) -> None:
...

@typing.overload
def __init__(
self,
items: typing.Mapping[KeyT, DataT],
*,
builder: typing.Callable[[DataT], ValueT],
) -> None:
...

def __init__(
self,
items: typing.Union[typing.Mapping[KeyT, ValueT], typing.Mapping[KeyT, DataT]],
*,
builder: typing.Optional[typing.Callable[[DataT], ValueT]] = None,
predicate: typing.Optional[typing.Callable[[typing.Any], bool]] = None,
) -> None:
self._builder = builder
self._data = items
self._predicate = predicate

@classmethod
def _copy(cls, value: ValueT) -> ValueT:
@staticmethod
def _copy(value: ValueT) -> ValueT:
return copy.copy(value)

def __contains__(self, key: typing.Any) -> bool:
return key in self._data and (self._predicate is None or self._predicate(self._data[key]))
return key in self._data

def __getitem__(self, key: KeyT) -> ValueT:
entry = self._data[key]

if self._predicate is not None and not self._predicate(entry):
raise KeyError(key)

if self._builder is not None:
entry = self._builder(entry) # type: ignore[arg-type]

else:
entry = self._copy(entry) # type: ignore[arg-type]
if self._builder:
return self._builder(entry) # type: ignore[arg-type]

return entry
return self._copy(entry) # type: ignore[arg-type]

def __iter__(self) -> typing.Iterator[KeyT]:
if self._predicate is None:
return iter(self._data)
else:
return (key for key, value in self._data.items() if self._predicate(value))
return iter(self._data)

def __len__(self) -> int:
if self._predicate is None:
return len(self._data)
else:
return sum(1 for value in self._data.values() if self._predicate(value))
return len(self._data)

def get_item_at(self, index: int) -> ValueT:
current_index = -1
@typing.overload
def get_item_at(self, index: int, /) -> ValueT:
...

for key, value in self._data.items():
if self._predicate is None or self._predicate(value):
index += 1
@typing.overload
def get_item_at(self, index: slice, /) -> typing.Sequence[ValueT]:
...

if current_index == index:
return self[key]

raise IndexError(index)
def get_item_at(self, index: typing.Union[slice, int], /) -> typing.Union[ValueT, typing.Sequence[ValueT]]:
return collections.get_index_or_slice(self, index)

def iterator(self) -> iterators.LazyIterator[ValueT]:
return iterators.FlatLazyIterator(self.values())
Expand All @@ -179,7 +176,7 @@ def __iter__(self) -> typing.Iterator[typing.Any]:
def __len__(self) -> typing.Literal[0]:
return 0

def get_item_at(self, index: int) -> typing.NoReturn:
def get_item_at(self, index: typing.Union[slice, int]) -> typing.NoReturn:
raise IndexError(index)

def iterator(self) -> iterators.LazyIterator[ValueT]:
Expand Down Expand Up @@ -740,7 +737,7 @@ def _copy_embed(embed: embeds_.Embed) -> embeds_.Embed:
author=copy.copy(embed.author) if embed.author else None,
provider=copy.copy(embed.provider) if embed.provider else None,
footer=copy.copy(embed.footer) if embed.footer else None,
fields=list(map(copy.copy, embed.fields)), # type: ignore[arg-type]
fields=[copy.copy(field) for field in embed.fields],
)


Expand Down Expand Up @@ -1028,6 +1025,6 @@ class Cache3DMappingView(CacheMappingView[snowflakes.Snowflake, cache.CacheView[

__slots__: typing.Sequence[str] = ()

@classmethod
def _copy(cls, value: cache.CacheView[KeyT, ValueT]) -> cache.CacheView[KeyT, ValueT]:
@staticmethod
def _copy(value: cache.CacheView[KeyT, ValueT]) -> cache.CacheView[KeyT, ValueT]:
return value

0 comments on commit 88dce72

Please sign in to comment.