Skip to content

Commit

Permalink
Pickling of generic annotations/types in 3.5+ (#318)
Browse files Browse the repository at this point in the history
  • Loading branch information
valtron authored Apr 23, 2020
1 parent d8452cc commit 1ba10a5
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 117 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ jobs:
python -m pip install -r dev-requirements.txt
python ci/install_coverage_subprocess_pth.py
export
- name: Install optional typing_extensions in Python 3.6
shell: bash
run: python -m pip install typing-extensions
if: matrix.python_version == '3.6'
- name: Display Python version
shell: bash
run: python -c "import sys; print(sys.version)"
Expand Down
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

**This version requires Python 3.5 or later**

- cloudpickle can now all pickle all constructs from the ``typing`` module
and the ``typing_extensions`` library in Python 3.5+
([PR #318](https://github.com/cloudpipe/cloudpickle/pull/318))

- Stop pickling the annotations of a dynamic class for Python < 3.6
(follow up on #276)
([issue #347](https://github.com/cloudpipe/cloudpickle/issues/347))
Expand Down
91 changes: 65 additions & 26 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,18 @@
import typing
from enum import Enum

from typing import Generic, Union, Tuple, Callable, ClassVar
from pickle import _Pickler as Pickler
from pickle import _getattribute
from io import BytesIO
from importlib._bootstrap import _find_spec

try: # pragma: no branch
import typing_extensions as _typing_extensions
from typing_extensions import Literal, Final
except ImportError:
_typing_extensions = Literal = Final = None


# cloudpickle is meant for inter process communication: we expect all
# communicating processes to run the same Python version hence we favor
Expand Down Expand Up @@ -117,7 +124,18 @@ def _whichmodule(obj, name):
- Errors arising during module introspection are ignored, as those errors
are considered unwanted side effects.
"""
module_name = _get_module_attr(obj)
if sys.version_info[:2] < (3, 7) and isinstance(obj, typing.TypeVar): # pragma: no branch # noqa
# Workaround bug in old Python versions: prior to Python 3.7,
# T.__module__ would always be set to "typing" even when the TypeVar T
# would be defined in a different module.
#
# For such older Python versions, we ignore the __module__ attribute of
# TypeVar instances and instead exhaustively lookup those instances in
# all currently imported modules.
module_name = None
else:
module_name = getattr(obj, '__module__', None)

if module_name is not None:
return module_name
# Protect the iteration by using a copy of sys.modules against dynamic
Expand All @@ -140,23 +158,6 @@ def _whichmodule(obj, name):
return None


if sys.version_info[:2] < (3, 7): # pragma: no branch
# Workaround bug in old Python versions: prior to Python 3.7, T.__module__
# would always be set to "typing" even when the TypeVar T would be defined
# in a different module.
#
# For such older Python versions, we ignore the __module__ attribute of
# TypeVar instances and instead exhaustively lookup those instances in all
# currently imported modules via the _whichmodule function.
def _get_module_attr(obj):
if isinstance(obj, typing.TypeVar):
return None
return getattr(obj, '__module__', None)
else:
def _get_module_attr(obj):
return getattr(obj, '__module__', None)


def _is_importable_by_name(obj, name=None):
"""Determine if obj can be pickled as attribute of a file-backed module"""
return _lookup_module_and_qualname(obj, name=name) is not None
Expand Down Expand Up @@ -423,6 +424,18 @@ def _extract_class_dict(cls):
return clsdict


if sys.version_info[:2] < (3, 7): # pragma: no branch
def _is_parametrized_type_hint(obj):
# This is very cheap but might generate false positives.
origin = getattr(obj, '__origin__', None) # typing Constructs
values = getattr(obj, '__values__', None) # typing_extensions.Literal
type_ = getattr(obj, '__type__', None) # typing_extensions.Final
return origin is not None or values is not None or type_ is not None

def _create_parametrized_type_hint(origin, args):
return origin[args]


class CloudPickler(Pickler):

dispatch = Pickler.dispatch.copy()
Expand Down Expand Up @@ -611,11 +624,6 @@ def save_dynamic_class(self, obj):
if isinstance(__dict__, property):
type_kwargs['__dict__'] = __dict__

if sys.version_info < (3, 7):
# Although annotations were added in Python 3.4, It is not possible
# to properly pickle them until Python 3.7. (See #193)
clsdict.pop('__annotations__', None)

save = self.save
write = self.write

Expand Down Expand Up @@ -715,9 +723,7 @@ def save_function_tuple(self, func):
'doc': func.__doc__,
'_cloudpickle_submodules': submodules
}
if hasattr(func, '__annotations__') and sys.version_info >= (3, 7):
# Although annotations were added in Python3.4, It is not possible
# to properly pickle them until Python3.7. (See #193)
if hasattr(func, '__annotations__'):
state['annotations'] = func.__annotations__
if hasattr(func, '__qualname__'):
state['qualname'] = func.__qualname__
Expand Down Expand Up @@ -800,6 +806,14 @@ def save_global(self, obj, name=None, pack=struct.pack):
elif obj in _BUILTIN_TYPE_NAMES:
return self.save_reduce(
_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)

if sys.version_info[:2] < (3, 7) and _is_parametrized_type_hint(obj): # noqa # pragma: no branch
# Parametrized typing constructs in Python < 3.7 are not compatible
# with type checks and ``isinstance`` semantics. For this reason,
# it is easier to detect them using a duck-typing-based check
# (``_is_parametrized_type_hint``) than to populate the Pickler's
# dispatch with type-specific savers.
self._save_parametrized_type_hint(obj)
elif name is not None:
Pickler.save_global(self, obj, name=name)
elif not _is_importable_by_name(obj, name=name):
Expand Down Expand Up @@ -941,6 +955,31 @@ def inject_addons(self):
"""Plug in system. Register additional pickling functions if modules already loaded"""
pass

if sys.version_info < (3, 7): # pragma: no branch
def _save_parametrized_type_hint(self, obj):
# The distorted type check sematic for typing construct becomes:
# ``type(obj) is type(TypeHint)``, which means "obj is a
# parametrized TypeHint"
if type(obj) is type(Literal): # pragma: no branch
initargs = (Literal, obj.__values__)
elif type(obj) is type(Final): # pragma: no branch
initargs = (Final, obj.__type__)
elif type(obj) is type(ClassVar):
initargs = (ClassVar, obj.__type__)
elif type(obj) in [type(Union), type(Tuple), type(Generic)]:
initargs = (obj.__origin__, obj.__args__)
elif type(obj) is type(Callable):
args = obj.__args__
if args[0] is Ellipsis:
initargs = (obj.__origin__, args)
else:
initargs = (obj.__origin__, (list(args[:-1]), args[-1]))
else: # pragma: no cover
raise pickle.PicklingError(
"Cloudpickle Error: Unknown type {}".format(type(obj))
)
self.save_reduce(_create_parametrized_type_hint, initargs, obj=obj)


# Tornado support

Expand Down
130 changes: 39 additions & 91 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,9 +1787,6 @@ def g():

self.assertEqual(f2.__doc__, f.__doc__)

@unittest.skipIf(sys.version_info < (3, 7),
"Pickling type annotations isn't supported for py36 and "
"below.")
def test_wraps_preserves_function_annotations(self):
def f(x):
pass
Expand All @@ -1804,79 +1801,7 @@ def g(x):

self.assertEqual(f2.__annotations__, f.__annotations__)

@unittest.skipIf(sys.version_info >= (3, 7),
"pickling annotations is supported starting Python 3.7")
def test_function_annotations_silent_dropping(self):
# Because of limitations of typing module, cloudpickle does not pickle
# the type annotations of a dynamic function or class for Python < 3.7

class UnpicklableAnnotation:
# Mock Annotation metaclass that errors out loudly if we try to
# pickle one of its instances
def __reduce__(self):
raise Exception("not picklable")

unpickleable_annotation = UnpicklableAnnotation()

def f(a: unpickleable_annotation):
return a

with pytest.raises(Exception):
cloudpickle.dumps(f.__annotations__)

depickled_f = pickle_depickle(f, protocol=self.protocol)
assert depickled_f.__annotations__ == {}

@unittest.skipIf(sys.version_info >= (3, 7) or sys.version_info < (3, 6),
"pickling annotations is supported starting Python 3.7")
def test_class_annotations_silent_dropping(self):
# Because of limitations of typing module, cloudpickle does not pickle
# the type annotations of a dynamic function or class for Python < 3.7

# Pickling and unpickling must be done in different processes when
# testing dynamic classes (see #313)

code = '''if 1:
import cloudpickle
import sys
class UnpicklableAnnotation:
# Mock Annotation metaclass that errors out loudly if we try to
# pickle one of its instances
def __reduce__(self):
raise Exception("not picklable")
unpickleable_annotation = UnpicklableAnnotation()
class A:
a: unpickleable_annotation
try:
cloudpickle.dumps(A.__annotations__)
except Exception:
pass
else:
raise AssertionError
sys.stdout.buffer.write(cloudpickle.dumps(A, protocol={protocol}))
'''
cmd = [sys.executable, '-c', code.format(protocol=self.protocol)]
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
proc.wait()
out, err = proc.communicate()
assert proc.returncode == 0, err

depickled_a = pickle.loads(out)
assert not hasattr(depickled_a, "__annotations__")

@unittest.skipIf(sys.version_info < (3, 7),
"Pickling type hints isn't supported for py36"
" and below.")
def test_type_hint(self):
# Try to pickle compound typing constructs. This would typically fail
# on Python < 3.7 (See #193)
t = typing.Union[list, int]
assert pickle_depickle(t) == t

Expand Down Expand Up @@ -2142,16 +2067,17 @@ def test_pickle_importable_typevar(self):
from typing import AnyStr
assert AnyStr is pickle_depickle(AnyStr, protocol=self.protocol)

@unittest.skipIf(sys.version_info < (3, 7),
"Pickling generics not supported below py37")
def test_generic_type(self):
T = typing.TypeVar('T')

class C(typing.Generic[T]):
pass

assert pickle_depickle(C, protocol=self.protocol) is C
assert pickle_depickle(C[int], protocol=self.protocol) is C[int]

# Identity is not part of the typing contract: only test for
# equality instead.
assert pickle_depickle(C[int], protocol=self.protocol) == C[int]

with subprocess_worker(protocol=self.protocol) as worker:

Expand All @@ -2170,33 +2096,55 @@ def check_generic(generic, origin, type_value):
assert check_generic(C[int], C, int) == "ok"
assert worker.run(check_generic, C[int], C, int) == "ok"

@unittest.skipIf(sys.version_info < (3, 7),
"Pickling type hints not supported below py37")
def test_locally_defined_class_with_type_hints(self):
with subprocess_worker(protocol=self.protocol) as worker:
for type_ in _all_types_to_test():
# The type annotation syntax causes a SyntaxError on Python 3.5
code = textwrap.dedent("""\
class MyClass:
attribute: type_
def method(self, arg: type_) -> type_:
return arg
""")
ns = {"type_": type_}
exec(code, ns)
MyClass = ns["MyClass"]
MyClass.__annotations__ = {'attribute': type_}

def check_annotations(obj, expected_type):
assert obj.__annotations__["attribute"] is expected_type
assert obj.method.__annotations__["arg"] is expected_type
assert obj.method.__annotations__["return"] is expected_type
assert obj.__annotations__["attribute"] == expected_type
assert obj.method.__annotations__["arg"] == expected_type
assert (
obj.method.__annotations__["return"] == expected_type
)
return "ok"

obj = MyClass()
assert check_annotations(obj, type_) == "ok"
assert worker.run(check_annotations, obj, type_) == "ok"

def test_generic_extensions(self):
typing_extensions = pytest.importorskip('typing_extensions')

objs = [
typing_extensions.Literal,
typing_extensions.Final,
typing_extensions.Literal['a'],
typing_extensions.Final[int],
]

for obj in objs:
depickled_obj = pickle_depickle(obj, protocol=self.protocol)
assert depickled_obj == obj

def test_class_annotations(self):
class C:
pass
C.__annotations__ = {'a': int}

C1 = pickle_depickle(C, protocol=self.protocol)
assert C1.__annotations__ == C.__annotations__

def test_function_annotations(self):
def f(a: int) -> str:
pass

f1 = pickle_depickle(f, protocol=self.protocol)
assert f1.__annotations__ == f.__annotations__


class Protocol2CloudPickleTest(CloudPickleTest):

Expand Down

0 comments on commit 1ba10a5

Please sign in to comment.