Skip to content

Commit

Permalink
backport ParamSpecArgs/Kwargs (#798)
Browse files Browse the repository at this point in the history
From python/cpython#25298. I also added more tests for get_args/get_origin,
which previously didn't exist in in typing_extensions.
  • Loading branch information
JelleZijlstra authored Apr 13, 2021
1 parent 40932e3 commit 4ba98e8
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 14 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ Workflow
Workflow for PyPI releases
--------------------------

* Run tests under all supported versions. As of May 2019 this includes
2.7, 3.4, 3.5, 3.6, 3.7.
* Run tests under all supported versions. As of April 2021 this includes
2.7, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9.

* On macOS, you can use `pyenv <https://github.com/pyenv/pyenv>`_ to
manage multiple Python installations. Long story short:
Expand Down
4 changes: 4 additions & 0 deletions typing_extensions/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ Python 3.4+ only:
-----------------

- ``ChainMap``
- ``ParamSpec``
- ``Concatenate``
- ``ParamSpecArgs``
- ``ParamSpecKwargs``

Python 3.5+ only:
-----------------
Expand Down
95 changes: 89 additions & 6 deletions typing_extensions/src_py3/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import subprocess
import types
from unittest import TestCase, main, skipUnless, skipIf
from typing import TypeVar, Optional
from typing import TypeVar, Optional, Union
from typing import T, KT, VT # Not in __all__.
from typing import Tuple, List, Dict, Iterator
from typing import Tuple, List, Dict, Iterator, Callable
from typing import Generic
from typing import no_type_check
from typing_extensions import NoReturn, ClassVar, Final, IntVar, Literal, Type, NewType, TypedDict
from typing_extensions import TypeAlias, ParamSpec, Concatenate
from typing_extensions import TypeAlias, ParamSpec, Concatenate, ParamSpecArgs, ParamSpecKwargs

try:
from typing_extensions import Protocol, runtime, runtime_checkable
Expand Down Expand Up @@ -519,6 +519,80 @@ def test_final_forward_ref(self):
self.assertNotEqual(gth(Loop, globals())['attr'], Final)


@skipUnless(PEP_560, "Python 3.7+ required")
class GetUtilitiesTestCase(TestCase):
def test_get_origin(self):
from typing_extensions import get_origin

T = TypeVar('T')
P = ParamSpec('P')
class C(Generic[T]): pass
self.assertIs(get_origin(C[int]), C)
self.assertIs(get_origin(C[T]), C)
self.assertIs(get_origin(int), None)
self.assertIs(get_origin(ClassVar[int]), ClassVar)
self.assertIs(get_origin(Union[int, str]), Union)
self.assertIs(get_origin(Literal[42, 43]), Literal)
self.assertIs(get_origin(Final[List[int]]), Final)
self.assertIs(get_origin(Generic), Generic)
self.assertIs(get_origin(Generic[T]), Generic)
self.assertIs(get_origin(List[Tuple[T, T]][int]), list)
self.assertIs(get_origin(Annotated[T, 'thing']), Annotated)
self.assertIs(get_origin(List), list)
self.assertIs(get_origin(Tuple), tuple)
self.assertIs(get_origin(Callable), collections.abc.Callable)
if sys.version_info >= (3, 9):
self.assertIs(get_origin(list[int]), list)
self.assertIs(get_origin(list), None)
self.assertIs(get_origin(P.args), P)
self.assertIs(get_origin(P.kwargs), P)

def test_get_args(self):
from typing_extensions import get_args

T = TypeVar('T')
class C(Generic[T]): pass
self.assertEqual(get_args(C[int]), (int,))
self.assertEqual(get_args(C[T]), (T,))
self.assertEqual(get_args(int), ())
self.assertEqual(get_args(ClassVar[int]), (int,))
self.assertEqual(get_args(Union[int, str]), (int, str))
self.assertEqual(get_args(Literal[42, 43]), (42, 43))
self.assertEqual(get_args(Final[List[int]]), (List[int],))
self.assertEqual(get_args(Union[int, Tuple[T, int]][str]),
(int, Tuple[str, int]))
self.assertEqual(get_args(typing.Dict[int, Tuple[T, T]][Optional[int]]),
(int, Tuple[Optional[int], Optional[int]]))
self.assertEqual(get_args(Callable[[], T][int]), ([], int))
self.assertEqual(get_args(Callable[..., int]), (..., int))
self.assertEqual(get_args(Union[int, Callable[[Tuple[T, ...]], str]]),
(int, Callable[[Tuple[T, ...]], str]))
self.assertEqual(get_args(Tuple[int, ...]), (int, ...))
self.assertEqual(get_args(Tuple[()]), ((),))
self.assertEqual(get_args(Annotated[T, 'one', 2, ['three']]), (T, 'one', 2, ['three']))
self.assertEqual(get_args(List), ())
self.assertEqual(get_args(Tuple), ())
self.assertEqual(get_args(Callable), ())
if sys.version_info >= (3, 9):
self.assertEqual(get_args(list[int]), (int,))
self.assertEqual(get_args(list), ())
if sys.version_info >= (3, 9):
# Support Python versions with and without the fix for
# https://bugs.python.org/issue42195
# The first variant is for 3.9.2+, the second for 3.9.0 and 1
self.assertIn(get_args(collections.abc.Callable[[int], str]),
(([int], str), ([[int]], str)))
self.assertIn(get_args(collections.abc.Callable[[], str]),
(([], str), ([[]], str)))
self.assertEqual(get_args(collections.abc.Callable[..., str]), (..., str))
P = ParamSpec('P')
# In 3.9 and lower we use typing_extensions's hacky implementation
# of ParamSpec, which gets incorrectly wrapped in a list
self.assertIn(get_args(Callable[P, int]), [(P, int), ([P], int)])
self.assertEqual(get_args(Callable[Concatenate[int, P], int]),
(Concatenate[int, P], int))


class CollectionsAbcTests(BaseTestCase):

def test_isinstance_collections(self):
Expand Down Expand Up @@ -1952,8 +2026,17 @@ def test_valid_uses(self):
# ParamSpec instances should also have args and kwargs attributes.
self.assertIn('args', dir(P))
self.assertIn('kwargs', dir(P))
P.args
P.kwargs

def test_args_kwargs(self):
P = ParamSpec('P')
self.assertIn('args', dir(P))
self.assertIn('kwargs', dir(P))
self.assertIsInstance(P.args, ParamSpecArgs)
self.assertIsInstance(P.kwargs, ParamSpecKwargs)
self.assertIs(P.args.__origin__, P)
self.assertIs(P.kwargs.__origin__, P)
self.assertEqual(repr(P.args), "P.args")
self.assertEqual(repr(P.kwargs), "P.kwargs")

# Note: ParamSpec doesn't work for pre-3.10 user-defined Generics due
# to type checks inside Generic.
Expand Down Expand Up @@ -2072,7 +2155,7 @@ def test_typing_extensions_defers_when_possible(self):
'Final',
'get_type_hints'
}
if sys.version_info[:2] == (3, 8):
if sys.version_info < (3, 10):
exclude |= {'get_args', 'get_origin'}
for item in typing_extensions.__all__:
if item not in exclude and hasattr(typing, item):
Expand Down
85 changes: 79 additions & 6 deletions typing_extensions/src_py3/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2065,11 +2065,23 @@ class Annotated(metaclass=AnnotatedMeta):

# Python 3.8 has get_origin() and get_args() but those implementations aren't
# Annotated-aware, so we can't use those, only Python 3.9 versions will do.
if sys.version_info[:2] >= (3, 9):
# Similarly, Python 3.9's implementation doesn't support ParamSpecArgs and
# ParamSpecKwargs.
if sys.version_info[:2] >= (3, 10):
get_origin = typing.get_origin
get_args = typing.get_args
elif PEP_560:
from typing import _GenericAlias # noqa
from typing import _GenericAlias
try:
# 3.9+
from typing import _BaseGenericAlias
except ImportError:
_BaseGenericAlias = _GenericAlias
try:
# 3.9+
from typing import GenericAlias
except ImportError:
GenericAlias = _GenericAlias

def get_origin(tp):
"""Get the unsubscripted version of a type.
Expand All @@ -2084,10 +2096,12 @@ def get_origin(tp):
get_origin(Generic[T]) is Generic
get_origin(Union[T, int]) is Union
get_origin(List[Tuple[T, T]][int]) == list
get_origin(P.args) is P
"""
if isinstance(tp, _AnnotatedAlias):
return Annotated
if isinstance(tp, _GenericAlias):
if isinstance(tp, (_GenericAlias, GenericAlias, _BaseGenericAlias,
ParamSpecArgs, ParamSpecKwargs)):
return tp.__origin__
if tp is Generic:
return Generic
Expand All @@ -2106,7 +2120,9 @@ def get_args(tp):
"""
if isinstance(tp, _AnnotatedAlias):
return (tp.__origin__,) + tp.__metadata__
if isinstance(tp, _GenericAlias) and not tp._special:
if isinstance(tp, (_GenericAlias, GenericAlias)):
if getattr(tp, "_special", False):
return ()
res = tp.__args__
if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis:
res = (list(res[:-1]), res[-1])
Expand Down Expand Up @@ -2210,9 +2226,60 @@ class TypeAlias(metaclass=_TypeAliasMeta, _root=True):


# Python 3.10+ has PEP 612
if hasattr(typing, 'ParamSpecArgs'):
ParamSpecArgs = typing.ParamSpecArgs
ParamSpecKwargs = typing.ParamSpecKwargs
else:
class _Immutable:
"""Mixin to indicate that object should not be copied."""
__slots__ = ()

def __copy__(self):
return self

def __deepcopy__(self, memo):
return self

class ParamSpecArgs(_Immutable):
"""The args for a ParamSpec object.
Given a ParamSpec object P, P.args is an instance of ParamSpecArgs.
ParamSpecArgs objects have a reference back to their ParamSpec:
P.args.__origin__ is P
This type is meant for runtime introspection and has no special meaning to
static type checkers.
"""
def __init__(self, origin):
self.__origin__ = origin

def __repr__(self):
return "{}.args".format(self.__origin__.__name__)

class ParamSpecKwargs(_Immutable):
"""The kwargs for a ParamSpec object.
Given a ParamSpec object P, P.kwargs is an instance of ParamSpecKwargs.
ParamSpecKwargs objects have a reference back to their ParamSpec:
P.kwargs.__origin__ is P
This type is meant for runtime introspection and has no special meaning to
static type checkers.
"""
def __init__(self, origin):
self.__origin__ = origin

def __repr__(self):
return "{}.kwargs".format(self.__origin__.__name__)

if hasattr(typing, 'ParamSpec'):
ParamSpec = typing.ParamSpec
else:

# Inherits from list as a workaround for Callable checks in Python < 3.9.2.
class ParamSpec(list):
"""Parameter specification variable.
Expand Down Expand Up @@ -2260,8 +2327,14 @@ def add_two(x: float, y: float) -> float:
Note that only parameter specification variables defined in global scope can
be pickled.
"""
args = object()
kwargs = object()

@property
def args(self):
return ParamSpecArgs(self)

@property
def kwargs(self):
return ParamSpecKwargs(self)

def __init__(self, name, *, bound=None, covariant=False, contravariant=False):
super().__init__([self])
Expand Down

0 comments on commit 4ba98e8

Please sign in to comment.