From 0fa524a75f9f5e3ef1ead84e4fe240aa728d3719 Mon Sep 17 00:00:00 2001 From: Leonardo Gama Date: Sat, 9 Jul 2022 23:30:37 -0300 Subject: [PATCH] Fix load_session() and restrict loading a session in a different module (#507) * Don't update vars(main) twice * Inspect the pickle beginnig to identify main and check against 'main' argument * Save and restore modules created at runtime with ModuleType() * tests: don't need to add runtime module to sys.modules * load_session_copy(): load a session state into a runtime module * tests: session tests code reorganization * tests: test runtime created module session saving * tests: test load_session_copy * review: adjustments * small fixes * use __dict__ * naming changes * review: final renaming and adjustments --- dill/__init__.py | 12 +- dill/_dill.py | 412 ++++++++++++++++++++++++++++++++++--- dill/tests/test_session.py | 325 +++++++++++++++-------------- docs/source/dill.rst | 2 +- 4 files changed, 558 insertions(+), 193 deletions(-) diff --git a/dill/__init__.py b/dill/__init__.py index fd4fb468..028112dc 100644 --- a/dill/__init__.py +++ b/dill/__init__.py @@ -23,11 +23,13 @@ del os, sys, parent, get_license_text, get_readme_as_rst -from ._dill import dump, dumps, load, loads, dump_session, load_session, \ - Pickler, Unpickler, register, copy, pickle, pickles, check, \ - HIGHEST_PROTOCOL, DEFAULT_PROTOCOL, PicklingError, UnpicklingError, \ - HANDLE_FMODE, CONTENTS_FMODE, FILE_FMODE, PickleError, PickleWarning, \ - PicklingWarning, UnpicklingWarning +from ._dill import ( + dump, dumps, load, loads, dump_module, load_module, load_module_asdict, + dump_session, load_session, Pickler, Unpickler, register, copy, pickle, + pickles, check, HIGHEST_PROTOCOL, DEFAULT_PROTOCOL, PicklingError, + UnpicklingError, HANDLE_FMODE, CONTENTS_FMODE, FILE_FMODE, PickleError, + PickleWarning, PicklingWarning, UnpicklingWarning, +) from . import source, temp, detect # get global settings diff --git a/dill/_dill.py b/dill/_dill.py index 96385b33..80293399 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -15,11 +15,14 @@ Test against "all" python types (Std. Lib. CH 1-15 @ 2.7) by mmckerns. Test against CH16+ Std. Lib. ... TBD. """ -__all__ = ['dump','dumps','load','loads','dump_session','load_session', - 'Pickler','Unpickler','register','copy','pickle','pickles', - 'check','HIGHEST_PROTOCOL','DEFAULT_PROTOCOL','PicklingError', - 'UnpicklingError','HANDLE_FMODE','CONTENTS_FMODE','FILE_FMODE', - 'PickleError','PickleWarning','PicklingWarning','UnpicklingWarning'] +__all__ = [ + 'dump', 'dumps', 'load', 'loads', 'dump_module', 'load_module', + 'load_module_asdict', 'dump_session', 'load_session', 'Pickler', 'Unpickler', + 'register', 'copy', 'pickle', 'pickles', 'check', 'HIGHEST_PROTOCOL', + 'DEFAULT_PROTOCOL', 'PicklingError', 'UnpicklingError', 'HANDLE_FMODE', + 'CONTENTS_FMODE', 'FILE_FMODE', 'PickleError', 'PickleWarning', + 'PicklingWarning', 'UnpicklingWarning' +] __module__ = 'dill' @@ -27,6 +30,8 @@ from .logger import adapter as logger from .logger import trace as _trace +from typing import Optional, Union + import os import sys diff = None @@ -315,8 +320,12 @@ def loads(str, ignore=None, **kwds): ### End: Shorthands ### ### Pickle the Interpreter Session +import pathlib +import tempfile + SESSION_IMPORTED_AS_TYPES = (ModuleType, ClassType, TypeType, Exception, FunctionType, MethodType, BuiltinMethodType) +TEMPDIR = pathlib.PurePath(tempfile.gettempdir()) def _module_map(): """get map of imported modules""" @@ -396,20 +405,80 @@ def _restore_modules(unpickler, main_module): pass #NOTE: 06/03/15 renamed main_module to main -def dump_session(filename='/tmp/session.pkl', main=None, byref=False, **kwds): - """pickle the current state of __main__ to a file""" +def dump_module( + filename = str(TEMPDIR/'session.pkl'), + main: Optional[Union[ModuleType, str]] = None, + refimported: bool = False, + **kwds +) -> None: + """Pickle the current state of :py:mod:`__main__` or another module to a file. + + Save the interpreter session (the contents of the built-in module + :py:mod:`__main__`) or the state of another module to a pickle file. This + can then be restored by calling the function :py:func:`load_module`. + + Runtime-created modules, like the ones constructed by + :py:class:`~types.ModuleType`, can also be saved and restored thereafter. + + Parameters: + filename: a path-like object or a writable stream. + main: a module object or an importable module name. + refimported: if `True`, all imported objects in the module's namespace + are saved by reference. *Note:* this is different from the ``byref`` + option of other "dump" functions and is not affected by + ``settings['byref']``. + **kwds: extra keyword arguments passed to :py:class:`Pickler()`. + + Raises: + :py:exc:`PicklingError`: if pickling fails. + + Examples: + - Save current session state: + + >>> import dill + >>> dill.dump_module() # save state of __main__ to /tmp/session.pkl + + - Save the state of an imported/importable module: + + >>> import my_mod as m + >>> m.var = 'new value' + >>> dill.dump_module('my_mod_session.pkl', main='my_mod') + + - Save the state of a non-importable, runtime-created module: + + >>> from types import ModuleType + >>> runtime = ModuleType('runtime') + >>> runtime.food = ['bacon', 'eggs', 'spam'] + >>> runtime.process_food = m.process_food + >>> dill.dump_module('runtime_session.pkl', main=runtime, refimported=True) + + *Changed in version 0.3.6:* the function ``dump_session()`` was renamed to + ``dump_module()``. + + *Changed in version 0.3.6:* the parameter ``byref`` was renamed to + ``refimported``. + """ + if 'byref' in kwds: + warnings.warn( + "The parameter 'byref' was renamed to 'refimported', use this" + " instead. Note: the underlying dill.Pickler do accept a 'byref'" + " argument, but it has no effect on session saving.", + PendingDeprecationWarning + ) + if refimported: + raise ValueError("both 'refimported' and 'byref' arguments were used.") + refimported = kwds.pop('byref') from .settings import settings protocol = settings['protocol'] if main is None: main = _main_module if hasattr(filename, 'write'): - f = filename + file = filename else: - f = open(filename, 'wb') + file = open(filename, 'wb') try: - pickler = Pickler(f, protocol, **kwds) - pickler._file = f + pickler = Pickler(file, protocol, **kwds) pickler._original_main = main - if byref: + if refimported: main = _stash_modules(main) pickler._main = main #FIXME: dill.settings are disabled pickler._byref = False # disable pickling by name reference @@ -419,29 +488,304 @@ def dump_session(filename='/tmp/session.pkl', main=None, byref=False, **kwds): pickler._main_modified = main is not pickler._original_main pickler.dump(main) finally: - if f is not filename: # If newly opened file - f.close() + if file is not filename: # if newly opened file + file.close() return -def load_session(filename='/tmp/session.pkl', main=None, **kwds): - """update the __main__ module with the state from the session file""" - if main is None: main = _main_module +# Backward compatibility. +def dump_session(filename=str(TEMPDIR/'session.pkl'), main=None, byref=False, **kwds): + warnings.warn("dump_session() was renamed to dump_module()", PendingDeprecationWarning) + dump_module(filename, main, refimported=byref, **kwds) +dump_session.__doc__ = dump_module.__doc__ + +class _PeekableReader: + """lightweight stream wrapper that implements peek()""" + def __init__(self, stream): + self.stream = stream + def read(self, n): + return self.stream.read(n) + def readline(self): + return self.stream.readline() + def tell(self): + return self.stream.tell() + def close(self): + return self.stream.close() + def peek(self, n): + stream = self.stream + try: + if hasattr(stream, 'flush'): stream.flush() + position = stream.tell() + stream.seek(position) # assert seek() works before reading + chunk = stream.read(n) + stream.seek(position) + return chunk + except (AttributeError, OSError): + raise NotImplementedError("stream is not peekable: %r", stream) from None + +def _make_peekable(stream): + """return stream as an object with a peek() method""" + import io + if hasattr(stream, 'peek'): + return stream + if not (hasattr(stream, 'tell') and hasattr(stream, 'seek')): + try: + return io.BufferedReader(stream) + except Exception: + pass + return _PeekableReader(stream) + +def _identify_module(file, main=None): + """identify the session file's module name""" + from pickletools import genops + UNICODE = {'UNICODE', 'BINUNICODE', 'SHORT_BINUNICODE'} + found_import = False + try: + for opcode, arg, pos in genops(file.peek(256)): + if not found_import: + if opcode.name in ('GLOBAL', 'SHORT_BINUNICODE') and \ + arg.endswith('_import_module'): + found_import = True + else: + if opcode.name in UNICODE: + return arg + else: + raise UnpicklingError("reached STOP without finding main module") + except (NotImplementedError, ValueError) as error: + # ValueError occours when the end of the chunk is reached (without a STOP). + if isinstance(error, NotImplementedError) and main is not None: + # file is not peekable, but we have main. + return None + raise UnpicklingError("unable to identify main module") from error + +def load_module( + filename = str(TEMPDIR/'session.pkl'), + main: Union[ModuleType, str] = None, + **kwds +) -> Optional[ModuleType]: + """Update :py:mod:`__main__` or another module with the state from the + session file. + + Restore the interpreter session (the built-in module :py:mod:`__main__`) or + the state of another module from a pickle file created by the function + :py:func:`dump_module`. + + If loading the state of a (non-importable) runtime-created module, a version + of this module may be passed as the argument ``main``. Otherwise, a new + module object is created with :py:class:`~types.ModuleType` and returned + after it's updated. + + Parameters: + filename: a path-like object or a readable stream. + main: an importable module name or a module object. + **kwds: extra keyword arguments passed to :py:class:`Unpickler()`. + + Raises: + :py:exc:`UnpicklingError`: if unpickling fails. + :py:exc:`ValueError`: if the ``main`` argument and the session file's + module are incompatible. + + Returns: + The restored module if it's different from :py:mod:`__main__` and + wasn't passed as the ``main`` argument. + + Examples: + - Load a saved session state: + + >>> import dill, sys + >>> dill.load_module() # updates __main__ from /tmp/session.pkl + >>> restored_var + 'this variable was created/updated by load_module()' + + - Load the saved state of an importable module: + + >>> my_mod = dill.load_module('my_mod_session.pkl') + >>> my_mod.var + 'new value' + >>> my_mod in sys.modules.values() + True + + - Load the saved state of a non-importable, runtime-created module: + + >>> runtime = dill.load_module('runtime_session.pkl') + >>> runtime.process_food is my_mod.process_food # was saved by reference + True + >>> runtime in sys.modules.values() + False + + - Update the state of a non-importable, runtime-created module: + + >>> from types import ModuleType + >>> runtime = ModuleType('runtime') + >>> runtime.food = ['pizza', 'burger'] + >>> dill.load_module('runtime_session.pkl', main=runtime) + >>> runtime.food + ['bacon', 'eggs', 'spam'] + + *Changed in version 0.3.6:* the function ``load_session()`` was renamed to + ``load_module()``. + + See also: + :py:func:`load_module_asdict` to load the contents of a saved session + (from :py:mod:`__main__` or any importable module) into a dictionary. + """ + main_arg = main if hasattr(filename, 'read'): - f = filename + file = filename else: - f = open(filename, 'rb') - try: #FIXME: dill.settings are disabled - unpickler = Unpickler(f, **kwds) - unpickler._main = main + file = open(filename, 'rb') + try: + file = _make_peekable(file) + #FIXME: dill.settings are disabled + unpickler = Unpickler(file, **kwds) unpickler._session = True + pickle_main = _identify_module(file, main) + + # Resolve unpickler._main + if main is None and pickle_main is not None: + main = pickle_main + if isinstance(main, str): + if main.startswith('__runtime__.'): + # Create runtime module to load the session into. + main = ModuleType(main.partition('.')[-1]) + else: + main = _import_module(main) + if main is not None: + if not isinstance(main, ModuleType): + raise ValueError("%r is not a module" % main) + unpickler._main = main + else: + main = unpickler._main + + # Check against the pickle's main. + is_main_imported = _is_imported_module(main) + if pickle_main is not None: + is_runtime_mod = pickle_main.startswith('__runtime__.') + if is_runtime_mod: + pickle_main = pickle_main.partition('.')[-1] + if is_runtime_mod and is_main_imported: + raise ValueError( + "can't restore non-imported module %r into an imported one" + % pickle_main + ) + if not is_runtime_mod and not is_main_imported: + raise ValueError( + "can't restore imported module %r into a non-imported one" + % pickle_main + ) + if main.__name__ != pickle_main: + raise ValueError( + "can't restore module %r into module %r" + % (pickle_main, main.__name__) + ) + + # This is for find_class() to be able to locate it. + if not is_main_imported: + runtime_main = '__runtime__.%s' % main.__name__ + sys.modules[runtime_main] = main + module = unpickler.load() - unpickler._session = False - main.__dict__.update(module.__dict__) - _restore_modules(unpickler, main) finally: - if f is not filename: # If newly opened file - f.close() - return + if not hasattr(filename, 'read'): # if newly opened file + file.close() + try: + del sys.modules[runtime_main] + except (KeyError, NameError): + pass + assert module is main + _restore_modules(unpickler, module) + if not (module is _main_module or module is main_arg): + return module + +# Backward compatibility. +def load_session(filename=str(TEMPDIR/'session.pkl'), main=None, **kwds): + warnings.warn("load_session() was renamed to load_module().", PendingDeprecationWarning) + load_module(filename, main, **kwds) +load_session.__doc__ = load_module.__doc__ + +def load_module_asdict( + filename = str(TEMPDIR/'session.pkl'), + update: bool = False, + **kwds +) -> dict: + """ + Load the contents of a module from a session file into a dictionary. + + ``load_module_asdict()`` does the equivalent of this function:: + + lambda filename: vars(load_module(filename)).copy() + + but without changing the original module. + + The loaded module's origin is stored in the ``__session__`` attribute. + + Parameters: + filename: a path-like object or a readable stream + update: if `True`, the dictionary is updated with the current state of + the module before loading variables from the session file + **kwds: extra keyword arguments passed to :py:class:`Unpickler()` + + Raises: + :py:exc:`UnpicklingError`: if unpickling fails + + Returns: + A copy of the restored module's dictionary. + + Note: + If the ``update`` option is used, the original module will be loaded if + it wasn't yet. + + Example: + >>> import dill + >>> alist = [1, 2, 3] + >>> anum = 42 + >>> dill.dump_module() + >>> anum = 0 + >>> new_var = 'spam' + >>> main_vars = dill.load_module_asdict() + >>> main_vars['__name__'], main_vars['__session__'] + ('__main__', '/tmp/session.pkl') + >>> main_vars is globals() # loaded objects don't reference current global variables + False + >>> main_vars['alist'] == alist + True + >>> main_vars['alist'] is alist # was saved by value + False + >>> main_vars['anum'] == anum # changed after the session was saved + False + >>> new_var in main_vars # would be True if the option 'update' was set + False + """ + if 'main' in kwds: + raise TypeError("'main' is an invalid keyword argument for load_module_asdict()") + if hasattr(filename, 'read'): + file = filename + else: + file = open(filename, 'rb') + try: + file = _make_peekable(file) + main_name = _identify_module(file) + old_main = sys.modules.get(main_name) + main = ModuleType(main_name) + if update: + if old_main is None: + old_main = _import_module(main_name) + main.__dict__.update(old_main.__dict__) + else: + main.__builtins__ = __builtin__ + sys.modules[main_name] = main + load_module(file, **kwds) + main.__session__ = str(filename) + finally: + if not hasattr(filename, 'read'): # if newly opened file + file.close() + try: + if old_main is None: + del sys.modules[main_name] + else: + sys.modules[main_name] = old_main + except NameError: # failed before setting old_main + pass + return main.__dict__ ### End: Pickle the Interpreter @@ -1132,14 +1476,16 @@ def _dict_from_dictproxy(dictproxy): def _import_module(import_name, safe=False): try: - if '.' in import_name: + if import_name.startswith('__runtime__.'): + return sys.modules[import_name] + elif '.' in import_name: items = import_name.split('.') module = '.'.join(items[:-1]) obj = items[-1] else: return __import__(import_name) return getattr(__import__(module, None, None, [obj]), obj) - except (ImportError, AttributeError): + except (ImportError, AttributeError, KeyError): if safe: return None raise @@ -1719,6 +2065,9 @@ def _is_builtin_module(module): module.__file__.endswith(EXTENSION_SUFFIXES) or \ 'site-packages' in module.__file__ +def _is_imported_module(module): + return getattr(module, '__loader__', None) is not None or module in sys.modules.values() + @register(ModuleType) def save_module(pickler, obj): if False: #_use_diff: @@ -1746,7 +2095,8 @@ def save_module(pickler, obj): _main_dict = obj.__dict__.copy() #XXX: better no copy? option to copy? [_main_dict.pop(item, None) for item in singletontypes + ["__builtins__", "__loader__"]] - pickler.save_reduce(_import_module, (obj.__name__,), obj=obj, + mod_name = obj.__name__ if _is_imported_module(obj) else '__runtime__.%s' % obj.__name__ + pickler.save_reduce(_import_module, (mod_name,), obj=obj, state=_main_dict) logger.trace(pickler, "# M1") elif obj.__name__ == "dill._dill": diff --git a/dill/tests/test_session.py b/dill/tests/test_session.py index 689cc975..8f687934 100644 --- a/dill/tests/test_session.py +++ b/dill/tests/test_session.py @@ -9,60 +9,57 @@ import os import sys import __main__ +from io import BytesIO import dill -session_file = os.path.join(os.path.dirname(__file__), 'session-byref-%s.pkl') +session_file = os.path.join(os.path.dirname(__file__), 'session-refimported-%s.pkl') -def test_modules(main, byref): - main_dict = main.__dict__ +################### +# Child process # +################### - try: - for obj in ('json', 'url', 'local_mod', 'sax', 'dom'): - assert main_dict[obj].__name__ in sys.modules - - for obj in ('Calendar', 'isleap'): - assert main_dict[obj] is sys.modules['calendar'].__dict__[obj] - assert main.day_name.__module__ == 'calendar' - if byref: - assert main.day_name is sys.modules['calendar'].__dict__['day_name'] - - assert main.complex_log is sys.modules['cmath'].__dict__['log'] +def _error_line(error, obj, refimported): + import traceback + line = traceback.format_exc().splitlines()[-2].replace('[obj]', '['+repr(obj)+']') + return "while testing (with refimported=%s): %s" % (refimported, line.lstrip()) - except AssertionError: - import traceback - error_line = traceback.format_exc().splitlines()[-2].replace('[obj]', '['+repr(obj)+']') - print("Error while testing (byref=%s):" % byref, error_line, sep="\n", file=sys.stderr) - raise - - -# Test session loading in a fresh interpreter session. if __name__ == '__main__' and len(sys.argv) >= 3 and sys.argv[1] == '--child': - byref = sys.argv[2] == 'True' - dill.load_session(session_file % byref) - test_modules(__main__, byref) - sys.exit() + # Test session loading in a fresh interpreter session. + refimported = (sys.argv[2] == 'True') + dill.load_module(session_file % refimported) + + def test_modules(refimported): + # FIXME: In this test setting with CPython 3.7, 'calendar' is not included + # in sys.modules, independent of the value of refimported. Tried to + # run garbage collection just before loading the session with no luck. It + # fails even when preceding them with 'import calendar'. Needed to run + # these kinds of tests in a supbrocess. Failing test sample: + # assert globals()['day_name'] is sys.modules['calendar'].__dict__['day_name'] + try: + for obj in ('json', 'url', 'local_mod', 'sax', 'dom'): + assert globals()[obj].__name__ in sys.modules + assert 'calendar' in sys.modules and 'cmath' in sys.modules + import calendar, cmath -del test_modules + for obj in ('Calendar', 'isleap'): + assert globals()[obj] is sys.modules['calendar'].__dict__[obj] + assert __main__.day_name.__module__ == 'calendar' + if refimported: + assert __main__.day_name is calendar.day_name + assert __main__.complex_log is cmath.log -def _clean_up_cache(module): - cached = module.__file__.split('.', 1)[0] + '.pyc' - cached = module.__cached__ if hasattr(module, '__cached__') else cached - pycache = os.path.join(os.path.dirname(module.__file__), '__pycache__') - for remove, file in [(os.remove, cached), (os.removedirs, pycache)]: - try: - remove(file) - except OSError: - pass + except AssertionError as error: + error.args = (_error_line(error, obj, refimported),) + raise + test_modules(refimported) + sys.exit() -# To clean up namespace before loading the session. -original_modules = set(sys.modules.keys()) - \ - set(['json', 'urllib', 'xml.sax', 'xml.dom.minidom', 'calendar', 'cmath']) -original_objects = set(__main__.__dict__.keys()) -original_objects.add('original_objects') - +#################### +# Parent process # +#################### # Create various kinds of objects to test different internal logics. @@ -72,7 +69,6 @@ def _clean_up_cache(module): from xml import sax # submodule import xml.dom.minidom as dom # submodule under alias import test_dictviews as local_mod # non-builtin top-level module -atexit.register(_clean_up_cache, local_mod) ## Imported objects. from calendar import Calendar, isleap, day_name # class, function, other object @@ -95,152 +91,169 @@ def weekdays(self): cal = CalendarSubclass() selfref = __main__ +# Setup global namespace for session saving tests. +class TestNamespace: + test_globals = globals().copy() + def __init__(self, **extra): + self.extra = extra + def __enter__(self): + self.backup = globals().copy() + globals().clear() + globals().update(self.test_globals) + globals().update(self.extra) + return self + def __exit__(self, *exc_info): + globals().clear() + globals().update(self.backup) -def test_objects(main, copy_dict, byref): - main_dict = main.__dict__ +def _clean_up_cache(module): + cached = module.__file__.split('.', 1)[0] + '.pyc' + cached = module.__cached__ if hasattr(module, '__cached__') else cached + pycache = os.path.join(os.path.dirname(module.__file__), '__pycache__') + for remove, file in [(os.remove, cached), (os.removedirs, pycache)]: + try: + remove(file) + except OSError: + pass - try: - for obj in ('json', 'url', 'local_mod', 'sax', 'dom'): - assert main_dict[obj].__name__ == copy_dict[obj].__name__ +atexit.register(_clean_up_cache, local_mod) - #FIXME: In the second test call, 'calendar' is not included in - # sys.modules, independent of the value of byref. Tried to run garbage - # collection before with no luck. This block fails even with - # "import calendar" before it. Needed to restore the original modules - # with the 'copy_modules' object. (Moved to "test_session_{1,2}.py".) +def _test_objects(main, globals_copy, refimported): + try: + main_dict = __main__.__dict__ + global Person, person, Calendar, CalendarSubclass, cal, selfref - #for obj in ('Calendar', 'isleap'): - # assert main_dict[obj] is sys.modules['calendar'].__dict__[obj] - #assert main_dict['day_name'].__module__ == 'calendar' - #if byref: - # assert main_dict['day_name'] is sys.modules['calendar'].__dict__['day_name'] + for obj in ('json', 'url', 'local_mod', 'sax', 'dom'): + assert globals()[obj].__name__ == globals_copy[obj].__name__ for obj in ('x', 'empty', 'names'): - assert main_dict[obj] == copy_dict[obj] + assert main_dict[obj] == globals_copy[obj] for obj in ['squared', 'cubed']: assert main_dict[obj].__globals__ is main_dict - assert main_dict[obj](3) == copy_dict[obj](3) + assert main_dict[obj](3) == globals_copy[obj](3) - assert main.Person.__module__ == main.__name__ - assert isinstance(main.person, main.Person) - assert main.person.age == copy_dict['person'].age + assert Person.__module__ == __main__.__name__ + assert isinstance(person, Person) + assert person.age == globals_copy['person'].age - assert issubclass(main.CalendarSubclass, main.Calendar) - assert isinstance(main.cal, main.CalendarSubclass) - assert main.cal.weekdays() == copy_dict['cal'].weekdays() + assert issubclass(CalendarSubclass, Calendar) + assert isinstance(cal, CalendarSubclass) + assert cal.weekdays() == globals_copy['cal'].weekdays() - assert main.selfref is main + assert selfref is __main__ - except AssertionError: - import traceback - error_line = traceback.format_exc().splitlines()[-2].replace('[obj]', '['+repr(obj)+']') - print("Error while testing (byref=%s):" % byref, error_line, sep="\n", file=sys.stderr) + except AssertionError as error: + error.args = (_error_line(error, obj, refimported),) raise +def test_session_main(refimported): + """test dump/load_module() for __main__, both in this process and in a subprocess""" + extra_objects = {} + if refimported: + # Test unpickleable imported object in main. + from sys import flags + extra_objects['flags'] = flags -if __name__ == '__main__': - - # Test dump_session() and load_session(). - for byref in (False, True): - if byref: - # Test unpickleable imported object in main. - from sys import flags - - #print(sorted(set(sys.modules.keys()) - original_modules)) - dill._test_file = dill._dill.StringIO() + with TestNamespace(**extra_objects) as ns: try: - # For the subprocess. - dill.dump_session(session_file % byref, byref=byref) - - dill.dump_session(dill._test_file, byref=byref) - dump = dill._test_file.getvalue() - dill._test_file.close() - - import __main__ - copy_dict = __main__.__dict__.copy() - copy_modules = sys.modules.copy() - del copy_dict['dump'] - del copy_dict['__main__'] - for name in copy_dict.keys(): - if name not in original_objects: - del __main__.__dict__[name] - for module in list(sys.modules.keys()): - if module not in original_modules: - del sys.modules[module] - - dill._test_file = dill._dill.StringIO(dump) - dill.load_session(dill._test_file) - #print(sorted(set(sys.modules.keys()) - original_modules)) - # Test session loading in a new session. + dill.dump_module(session_file % refimported, refimported=refimported) from dill.tests.__main__ import python, shell, sp - error = sp.call([python, __file__, '--child', str(byref)], shell=shell) + error = sp.call([python, __file__, '--child', str(refimported)], shell=shell) if error: sys.exit(error) - del python, shell, sp - finally: - dill._test_file.close() try: - os.remove(session_file % byref) + os.remove(session_file % refimported) except OSError: pass - test_objects(__main__, copy_dict, byref) - __main__.__dict__.update(copy_dict) - sys.modules.update(copy_modules) - del __main__, copy_dict, copy_modules, dump - - - # This is for code coverage, tests the use case of dump_session(byref=True) - # without imported objects in the namespace. It's a contrived example because - # even dill can't be in it. - from types import ModuleType - modname = '__test_main__' - main = ModuleType(modname) - main.x = 42 - - _main = dill._dill._stash_modules(main) - if _main is not main: - print("There are objects to save by referenece that shouldn't be:", - _main.__dill_imported, _main.__dill_imported_as, _main.__dill_imported_top_level, - file=sys.stderr) + # Test session loading in the same session. + session_buffer = BytesIO() + dill.dump_module(session_buffer, refimported=refimported) + session_buffer.seek(0) + dill.load_module(session_buffer) + ns.backup['_test_objects'](__main__, ns.backup, refimported) - test_file = dill._dill.StringIO() - try: - dill.dump_session(test_file, main=main, byref=True) - dump = test_file.getvalue() - test_file.close() - - main = sys.modules[modname] = ModuleType(modname) # empty - # This should work after fixing https://github.com/uqfoundation/dill/issues/462 - test_file = dill._dill.StringIO(dump) - dill.load_session(test_file, main=main) - finally: - test_file.close() - - assert main.x == 42 - - - # Dump session for module that is not __main__: +def test_session_other(): + """test dump/load_module() for a module other than __main__""" import test_classdef as module atexit.register(_clean_up_cache, module) module.selfref = module dict_objects = [obj for obj in module.__dict__.keys() if not obj.startswith('__')] - test_file = dill._dill.StringIO() - try: - dill.dump_session(test_file, main=module) - dump = test_file.getvalue() - test_file.close() + session_buffer = BytesIO() + dill.dump_module(session_buffer, main=module) - for obj in dict_objects: - del module.__dict__[obj] + for obj in dict_objects: + del module.__dict__[obj] - test_file = dill._dill.StringIO(dump) - dill.load_session(test_file, main=module) - finally: - test_file.close() + session_buffer.seek(0) + dill.load_module(session_buffer) #, main=module) assert all(obj in module.__dict__ for obj in dict_objects) assert module.selfref is module + +def test_runtime_module(): + from types import ModuleType + modname = '__runtime__' + runtime = ModuleType(modname) + runtime.x = 42 + + mod = dill._dill._stash_modules(runtime) + if mod is not runtime: + print("There are objects to save by referenece that shouldn't be:", + mod.__dill_imported, mod.__dill_imported_as, mod.__dill_imported_top_level, + file=sys.stderr) + + # This is also for code coverage, tests the use case of dump_module(refimported=True) + # without imported objects in the namespace. It's a contrived example because + # even dill can't be in it. This should work after fixing #462. + session_buffer = BytesIO() + dill.dump_module(session_buffer, main=runtime, refimported=True) + session_dump = session_buffer.getvalue() + + # Pass a new runtime created module with the same name. + runtime = ModuleType(modname) # empty + return_val = dill.load_module(BytesIO(session_dump), main=runtime) + assert return_val is None + assert runtime.__name__ == modname + assert runtime.x == 42 + assert runtime not in sys.modules.values() + + # Pass nothing as main. load_module() must create it. + session_buffer.seek(0) + runtime = dill.load_module(BytesIO(session_dump)) + assert runtime.__name__ == modname + assert runtime.x == 42 + assert runtime not in sys.modules.values() + +def test_load_module_asdict(): + with TestNamespace(): + session_buffer = BytesIO() + dill.dump_module(session_buffer) + + global empty, names, x, y + x = y = 0 # change x and create y + del empty + globals_state = globals().copy() + + session_buffer.seek(0) + main_vars = dill.load_module_asdict(session_buffer) + + assert main_vars is not globals() + assert globals() == globals_state + + assert main_vars['__name__'] == '__main__' + assert main_vars['names'] == names + assert main_vars['names'] is not names + assert main_vars['x'] != x + assert 'y' not in main_vars + assert 'empty' in main_vars + +if __name__ == '__main__': + test_session_main(refimported=False) + test_session_main(refimported=True) + test_session_other() + test_runtime_module() + test_load_module_asdict() diff --git a/docs/source/dill.rst b/docs/source/dill.rst index 9061863c..31d41c91 100644 --- a/docs/source/dill.rst +++ b/docs/source/dill.rst @@ -11,7 +11,7 @@ dill module :special-members: :show-inheritance: :imported-members: -.. :exclude-members: + :exclude-members: dump_session, load_session detect module -------------