Skip to content

Commit

Permalink
Add get_protocol_members and is_protocol (#238)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex Waygood <[email protected]>
  • Loading branch information
JelleZijlstra and AlexWaygood authored Jun 16, 2023
1 parent f9b83a2 commit 38bb6e8
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 14 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Unreleased

- Add `typing_extensions.get_protocol_members` and
`typing_extensions.is_protocol` (backport of CPython PR #104878).
Patch by Jelle Zijlstra.
- `typing_extensions` now re-exports all names in the standard library's
`typing` module, except the deprecated `ByteString`. Patch by Jelle
Zijlstra.
Expand All @@ -17,7 +20,7 @@
- Fix tests on Python 3.13, which removes support for creating
`TypedDict` classes through the keyword-argument syntax. Patch by
Jelle Zijlstra.
- Fix a regression introduced in v4.6.3 that meant that
- Fix a regression introduced in v4.6.3 that meant that
``issubclass(object, typing_extensions.Protocol)`` would erroneously raise
``TypeError``. Patch by Alex Waygood (backporting the CPython PR
https://github.com/python/cpython/pull/105239).
Expand Down
34 changes: 34 additions & 0 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,24 @@ Functions

.. versionadded:: 4.2.0

.. function:: get_protocol_members(tp)

Return the set of members defined in a :class:`Protocol`. This works with protocols
defined using either :class:`typing.Protocol` or :class:`typing_extensions.Protocol`.

::

>>> from typing_extensions import Protocol, get_protocol_members
>>> class P(Protocol):
... def a(self) -> str: ...
... b: int
>>> get_protocol_members(P)
frozenset({'a', 'b'})

Raise :py:exc:`TypeError` for arguments that are not Protocols.

.. versionadded:: 4.7.0

.. function:: get_type_hints(obj, globalns=None, localns=None, include_extras=False)

See :py:func:`typing.get_type_hints`.
Expand All @@ -634,6 +652,22 @@ Functions

Interaction with :data:`Required` and :data:`NotRequired`.

.. function:: is_protocol(tp)

Determine if a type is a :class:`Protocol`. This works with protocols
defined using either :py:class:`typing.Protocol` or :class:`typing_extensions.Protocol`.

For example::

class P(Protocol):
def a(self) -> str: ...
b: int

is_protocol(P) # => True
is_protocol(int) # => False

.. versionadded:: 4.7.0

.. function:: is_typeddict(tp)

See :py:func:`typing.is_typeddict`. In ``typing`` since 3.10.
Expand Down
126 changes: 113 additions & 13 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from typing_extensions import assert_type, get_type_hints, get_origin, get_args, get_original_bases
from typing_extensions import clear_overloads, get_overloads, overload
from typing_extensions import NamedTuple
from typing_extensions import override, deprecated, Buffer, TypeAliasType, TypeVar
from typing_extensions import override, deprecated, Buffer, TypeAliasType, TypeVar, get_protocol_members, is_protocol
from _typed_dict_test_helper import Foo, FooGeneric, VeryAnnotated

# Flags used to mark tests that only apply after a specific
Expand All @@ -52,6 +52,10 @@
# 3.12 changes the representation of Unpack[] (PEP 692)
TYPING_3_12_0 = sys.version_info[:3] >= (3, 12, 0)

only_with_typing_Protocol = skipUnless(
hasattr(typing, "Protocol"), "Only relevant when typing.Protocol exists"
)

# https://github.com/python/cpython/pull/27017 was backported into some 3.9 and 3.10
# versions, but not all
HAS_FORWARD_MODULE = "module" in inspect.signature(typing._type_check).parameters
Expand Down Expand Up @@ -1767,10 +1771,7 @@ class E(C, BP): pass
self.assertNotIsInstance(D(), E)
self.assertNotIsInstance(E(), D)

@skipUnless(
hasattr(typing, "Protocol"),
"Test is only relevant if typing.Protocol exists"
)
@only_with_typing_Protocol
def test_runtimecheckable_on_typing_dot_Protocol(self):
@runtime_checkable
class Foo(typing.Protocol):
Expand All @@ -1783,10 +1784,7 @@ def __init__(self):
self.assertIsInstance(Bar(), Foo)
self.assertNotIsInstance(object(), Foo)

@skipUnless(
hasattr(typing, "runtime_checkable"),
"Test is only relevant if typing.runtime_checkable exists"
)
@only_with_typing_Protocol
def test_typing_dot_runtimecheckable_on_Protocol(self):
@typing.runtime_checkable
class Foo(Protocol):
Expand All @@ -1799,10 +1797,7 @@ def __init__(self):
self.assertIsInstance(Bar(), Foo)
self.assertNotIsInstance(object(), Foo)

@skipUnless(
hasattr(typing, "Protocol"),
"Test is only relevant if typing.Protocol exists"
)
@only_with_typing_Protocol
def test_typing_Protocol_and_extensions_Protocol_can_mix(self):
class TypingProto(typing.Protocol):
x: int
Expand Down Expand Up @@ -2992,6 +2987,111 @@ def __call__(self, *args: Unpack[Ts]) -> T: ...
self.assertEqual(Y.__parameters__, ())
self.assertEqual(Y.__args__, (int, bytes, memoryview))

def test_get_protocol_members(self):
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(object)
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(object())
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(Protocol)
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(Generic)

class P(Protocol):
a: int
def b(self) -> str: ...
@property
def c(self) -> int: ...

self.assertEqual(get_protocol_members(P), {'a', 'b', 'c'})
self.assertIsInstance(get_protocol_members(P), frozenset)
self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__)

class Concrete:
a: int
def b(self) -> str: return "capybara"
@property
def c(self) -> int: return 5

with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(Concrete)
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(Concrete())

class ConcreteInherit(P):
a: int = 42
def b(self) -> str: return "capybara"
@property
def c(self) -> int: return 5

with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(ConcreteInherit)
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(ConcreteInherit())

@only_with_typing_Protocol
def test_get_protocol_members_typing(self):
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(typing.Protocol)

class P(typing.Protocol):
a: int
def b(self) -> str: ...
@property
def c(self) -> int: ...

self.assertEqual(get_protocol_members(P), {'a', 'b', 'c'})
self.assertIsInstance(get_protocol_members(P), frozenset)
if hasattr(P, "__protocol_attrs__"):
self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__)

class Concrete:
a: int
def b(self) -> str: return "capybara"
@property
def c(self) -> int: return 5

with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(Concrete)
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(Concrete())

class ConcreteInherit(P):
a: int = 42
def b(self) -> str: return "capybara"
@property
def c(self) -> int: return 5

with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(ConcreteInherit)
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(ConcreteInherit())

def test_is_protocol(self):
self.assertTrue(is_protocol(Proto))
self.assertTrue(is_protocol(Point))
self.assertFalse(is_protocol(Concrete))
self.assertFalse(is_protocol(Concrete()))
self.assertFalse(is_protocol(Generic))
self.assertFalse(is_protocol(object))

# Protocol is not itself a protocol
self.assertFalse(is_protocol(Protocol))

@only_with_typing_Protocol
def test_is_protocol_with_typing(self):
self.assertFalse(is_protocol(typing.Protocol))

class TypingProto(typing.Protocol):
a: int

self.assertTrue(is_protocol(TypingProto))

class Concrete(TypingProto):
a: int

self.assertFalse(is_protocol(Concrete))

@skip_if_py312b1
def test_interaction_with_isinstance_checks_on_superclasses_with_ABCMeta(self):
# Ensure the cache is empty, or this test won't work correctly
Expand Down
47 changes: 47 additions & 0 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@
'get_args',
'get_origin',
'get_original_bases',
'get_protocol_members',
'get_type_hints',
'IntVar',
'is_protocol',
'is_typeddict',
'Literal',
'NewType',
Expand Down Expand Up @@ -2902,6 +2904,51 @@ def __ror__(self, left):
return typing.Union[left, self]


if hasattr(typing, "is_protocol"):
is_protocol = typing.is_protocol
get_protocol_members = typing.get_protocol_members
else:
def is_protocol(__tp: type) -> bool:
"""Return True if the given type is a Protocol.
Example::
>>> from typing_extensions import Protocol, is_protocol
>>> class P(Protocol):
... def a(self) -> str: ...
... b: int
>>> is_protocol(P)
True
>>> is_protocol(int)
False
"""
return (
isinstance(__tp, type)
and getattr(__tp, '_is_protocol', False)
and __tp != Protocol
)

def get_protocol_members(__tp: type) -> typing.FrozenSet[str]:
"""Return the set of members defined in a Protocol.
Example::
>>> from typing_extensions import Protocol, get_protocol_members
>>> class P(Protocol):
... def a(self) -> str: ...
... b: int
>>> get_protocol_members(P)
frozenset({'a', 'b'})
Raise a TypeError for arguments that are not Protocols.
"""
if not is_protocol(__tp):
raise TypeError(f'{__tp!r} is not a Protocol')
if hasattr(__tp, '__protocol_attrs__'):
return frozenset(__tp.__protocol_attrs__)
return frozenset(_get_protocol_attrs(__tp))


# Aliases for items that have always been in typing.
# Explicitly assign these (rather than using `from typing import *` at the top),
# so that we get a CI error if one of these is deleted from typing.py
Expand Down

0 comments on commit 38bb6e8

Please sign in to comment.