Skip to content

Commit

Permalink
Add torch.serialization.safe_globals context manager (#127939)
Browse files Browse the repository at this point in the history
Add context manager mentioned in #127808 (review)

Pull Request resolved: #127939
Approved by: https://github.com/albanD
  • Loading branch information
mikaylagawarecki authored and pytorchmergebot committed Jul 12, 2024
1 parent f0d7164 commit 7c289c2
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/notes/serialization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,4 @@ The following utility functions are related to serialization:
.. autofunction:: add_safe_globals
.. autofunction:: clear_safe_globals
.. autofunction:: get_safe_globals
.. autoclass:: safe_globals
23 changes: 23 additions & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4329,6 +4329,29 @@ def test_safe_globals_for_weights_only(self):
finally:
torch.serialization.clear_safe_globals()

def test_safe_globals_context_manager_weights_only(self):
'''
Tests torch.serialization.safe_globals context manager
'''
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
p = torch.nn.Parameter(t)
sd = OrderedDict([('t', t), ('p', p)])

try:
torch.serialization.add_safe_globals([TestEmptySubclass])
with tempfile.NamedTemporaryFile() as f:
torch.save(sd, f)
with torch.serialization.safe_globals([TwoTensor]):
f.seek(0)
torch.load(f, weights_only=True)
self.assertTrue(torch.serialization.get_safe_globals() == [TestEmptySubclass])
f.seek(0)
with self.assertRaisesRegex(pickle.UnpicklingError,
"Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor"):
torch.load(f, weights_only=True)
finally:
torch.serialization.clear_safe_globals()

@unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda")
def test_tensor_subclass_map_location(self):
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
Expand Down
18 changes: 18 additions & 0 deletions torch/_weights_only_unpickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,24 @@ def _clear_safe_globals():
_marked_safe_globals_list = []


def _remove_safe_globals(globals_to_remove: List[Any]):
global _marked_safe_globals_list
_marked_safe_globals_list = list(
set(_marked_safe_globals_list) - set(globals_to_remove)
)


class _safe_globals:
def __init__(self, safe_globals: List[Any]):
self.safe_globals = safe_globals

def __enter__(self):
_add_safe_globals(self.safe_globals)

def __exit__(self, type, value, tb):
_remove_safe_globals(self.safe_globals)


# Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals
# For example if user had a script like
# torch.load(file_a)
Expand Down
27 changes: 27 additions & 0 deletions torch/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"clear_safe_globals",
"get_safe_globals",
"add_safe_globals",
"safe_globals",
]


Expand Down Expand Up @@ -230,6 +231,32 @@ def add_safe_globals(safe_globals: List[Any]) -> None:
_weights_only_unpickler._add_safe_globals(safe_globals)


class safe_globals(_weights_only_unpickler._safe_globals):
r"""Context-manager that adds certain globals as safe for ``weights_only`` load.
Args:
safe_globals: List of globals for weights_only load.
Example:
>>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
>>> import tempfile
>>> class MyTensor(torch.Tensor):
... pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
... torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
... with torch.serialization.safe_globals([MyTensor]):
... torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
# [-0.8234, 2.0500, -0.3657]])
>>> assert torch.serialization.get_safe_globals() == []
"""
pass


def _is_zipfile(f) -> bool:
# This is a stricter implementation than zipfile.is_zipfile().
# zipfile.is_zipfile() is True if the magic number appears anywhere in the
Expand Down

0 comments on commit 7c289c2

Please sign in to comment.