
Fix warning from torch.load starting in torch 2.4 #1064

Merged
5 commits merged into master from fix-torch-load-warning-weights-only on Sep 19, 2024

Conversation

BenjaminBossan
Collaborator

See discussion in #1063

Starting from PyTorch 2.4, there is a warning when torch.load is called without setting the weights_only argument. This is because in the future, the default will switch from False to True, which can result in a lot of errors when trying to load torch files (which are pickle files and thus insecure).

In this PR, we add an option for the user to influence the kwargs passed to torch.load so that they can control this behavior. If the user does not indicate otherwise, we use the same defaults as the installed torch version. Therefore, users will only encounter this issue via skorch if they would have encountered it via torch anyway.
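
For illustration, a minimal usage sketch: the torch_load_kwargs name is taken from this PR's diff, but exposing it as an __init__ argument and the module below are assumptions made for the example.

import torch
from torch import nn
from skorch import NeuralNetClassifier

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(20, 2)

    def forward(self, X):
        return torch.softmax(self.lin(X), dim=-1)

# Passing torch_load_kwargs explicitly overrides the version-dependent default;
# leaving it unset keeps the behavior of the installed torch version.
net = NeuralNetClassifier(MyModule, torch_load_kwargs={'weights_only': True})
net.initialize()
net.save_params(f_params='model.pt')
net.load_params(f_params='model.pt')  # forwards weights_only=True to torch.load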

Since it's not 100% certain whether the default will switch in torch 2.6.0, we may have to adjust the version check in the future.
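
For reference, a rough sketch of such a version-dependent default; the helper name follows the review suggestion further down, and the 2.6.0 cutoff is exactly the part that may need adjusting.

import torch
from packaging import version

def get_default_torch_load_kwargs():
    """Return torch.load kwargs that match the default of the installed torch."""
    # Assumption: torch flips the weights_only default to True starting with 2.6.0.
    if version.parse(torch.__version__) >= version.parse('2.6.0'):
        return {'weights_only': True}
    return {'weights_only': False}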

Besides directly testing the kwargs being passed on, a test was also added to ensure that net.load_params does not emit any warnings. This is already indirectly covered by some accelerate tests that are currently failing with torch 2.4, but it's better to have an explicit test.
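
One way to write such a no-warning check (a sketch, not necessarily the exact test added in this PR):

import warnings

def check_load_params_warns_nothing(net, path):
    # Escalate any warning (including the torch.load FutureWarning) to an error.
    with warnings.catch_warnings():
        warnings.simplefilter('error')
        net.load_params(f_params=path)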

After this is merged, the CI should pass when using torch 2.4.0.

skorch/utils.py Outdated
@@ -768,3 +769,16 @@ def _check_f_arguments(caller_name, **kwargs):
key = 'module_' if key == 'f_params' else key[2:] + '_'
kwargs_module[key] = val
return kwargs_module, kwargs_other


def check_torch_weights_only_default_true():
Member

Given how specific this function is to torch.load, can this return torch_load_kwargs itself?

Collaborator Author

Good point, I made the suggested change.

skorch/utils.py Outdated


def get_torch_load_kwargs():
"""Returns the kwargs passed to torch.load the correspond to the current
Member

Suggested change:
- """Returns the kwargs passed to torch.load the correspond to the current
+ """Returns the kwargs passed to torch.load that correspond to the current

skorch/utils.py Outdated
@@ -768,3 +769,18 @@ def _check_f_arguments(caller_name, **kwargs):
key = 'module_' if key == 'f_params' else key[2:] + '_'
kwargs_module[key] = val
return kwargs_module, kwargs_other


def get_torch_load_kwargs():
Member

Suggested change:
- def get_torch_load_kwargs():
+ def get_default_torch_load_kwargs():

skorch/net.py Outdated
@@ -2620,10 +2650,14 @@ def _get_state_dict(f_name):

return state_dict
else:
torch_load_kwargs = self.torch_load_kwargs
if torch_load_kwargs is None:
torch_load_kwargs = get_torch_load_kwargs()
Member

Suggested change:
- torch_load_kwargs = get_torch_load_kwargs()
+ torch_load_kwargs = get_default_torch_load_kwargs()

Instead, rely on the installed torch version and skip if it doesn't fit.
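
Presumably this means gating the test on the installed torch version, roughly like the following sketch (the import path and exact assertion are assumptions based on the diff above):

import pytest
import torch
from packaging import version

@pytest.mark.skipif(
    version.parse(torch.__version__) >= version.parse('2.6.0'),
    reason="the old torch.load default no longer applies on this torch version",
)
def test_default_load_kwargs_match_old_torch_default():
    from skorch.utils import get_default_torch_load_kwargs
    assert get_default_torch_load_kwargs() == {'weights_only': False}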
@BenjaminBossan
Collaborator Author

CI is failing for unrelated reasons since the latest accelerate release; I opened an issue about it:

huggingface/accelerate#3070

@byi8220

byi8220 commented Sep 3, 2024

Quick question about the (unrelated) failing CI: are the CI and integration tests run on multi-GPU environments at all?

@BenjaminBossan
Collaborator Author

Quick question about the (unrelated) failing CI: are the CI and integration tests run on multi-GPU environments at all?

No, we're only using the free runners from GitHub on this repo. Is there anything that we should check specifically on GPU?

@byi8220

byi8220 commented Sep 4, 2024

Is there anything that we should check specifically on GPU?

Not sure. I think the only way GPU training would affect pickling is on distributed setups. I'm actually not sure how reliable pickling a running distributed accelerator is (e.g. there are a lot of Stack Overflow and forum posts about running into issues with pickling generators or in a multiprocessing context).

@BenjaminBossan
Collaborator Author

If such a setting causes trouble, it's probably not just because of the accelerator, so I think we can disregard that for now.

@BenjaminBossan
Collaborator Author

@ottonemo have your points been addressed?

ottonemo merged commit e724424 into master on Sep 19, 2024
15 checks passed
BenjaminBossan deleted the fix-torch-load-warning-weights-only branch on September 19, 2024 at 13:52
BenjaminBossan added a commit that referenced this pull request Jan 27, 2025
Resolves #1090.

PyTorch plans to make the switch to weights_only=True for torch.load. We
already partly dealt with that in #1064 when it comes to
save_params/load_params. However, we still had a gap. Namely, when using
pickle directly, i.e. when going through __getstate__ and __setstate__,
we are still using torch.load and torch.save without handling
weights_only. This will cause trouble in the future when the default is
switched. But it's also annoying right now, because users will get the
FutureWarning about weights_only, even if they correctly pass
torch_load_kwargs (see #1090).

The reason why we use torch.save/torch.load for pickle is that those
functions are basically _extended_ pickle functions that have the
benefit of supporting the map_location argument to handle the device of
torch tensors, which we don't have for pickle. The map_location argument
is important, e.g. when saving a net that uses CUDA and loading it on a
machine without CUDA, we would otherwise run into an error.
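
As a small illustration of why map_location matters (the file name is made up):

import torch

# Remap all storages to CPU at load time, e.g. when the file was saved on a
# CUDA machine; plain pickle offers no equivalent hook.
state = torch.load('net_trained_on_gpu.pt', map_location=torch.device('cpu'))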

However, with the move to weights_only=True, these torch.save/torch.load
will become _reduced_ pickle functions, as they will only support a
small subset of objects by default. Therefore, we wouldn't be able to
rely on torch.save/torch.load for pickling the whole skorch object.

In this PR, we thus move to using plain pickle for this. However, now we
run into the issue of how to handle the map_location. The solution I
ended up with is to intercept torch's _load_from_bytes using a custom
Unpickler, and to specifically use torch.load there. That way, we can
pass the map_location and other torch_load_kwargs. The rest of the
unpickling process works as normal.
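
A minimal sketch of that interception; the class and function names are illustrative, not skorch's exact implementation:

import io
import pickle
import torch

class TorchStorageUnpickler(pickle.Unpickler):
    """Unpickler that routes torch's storage payloads through torch.load."""

    def __init__(self, *args, torch_load_kwargs=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._torch_load_kwargs = torch_load_kwargs or {}

    def find_class(self, module, name):
        # Torch tensors are pickled via torch.storage._load_from_bytes; intercept
        # it so that map_location and friends are honored during unpickling.
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda data: torch.load(io.BytesIO(data), **self._torch_load_kwargs)
        return super().find_class(module, name)

def pickle_loads_with_torch_kwargs(raw_bytes, **torch_load_kwargs):
    """Like pickle.loads, but with control over how torch storages are loaded."""
    return TorchStorageUnpickler(
        io.BytesIO(raw_bytes), torch_load_kwargs=torch_load_kwargs
    ).load()

For example, pickle_loads_with_torch_kwargs(blob, map_location='cpu') would then load a net pickled on a CUDA machine onto a CPU-only one.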

Yes, this is a private function, so we cannot be sure that it will keep
working indefinitely. If there is a better suggestion, I'm open to it.
However, the function has existed for 7 years, so it's not very likely
that it will change anytime soon:

https://github.com/pytorch/pytorch/blame/0674ab7e33c3f627ca6781ce98468ec1dd4743a5/torch/storage.py#L525

A drawback of the solution is that we cannot just load old skorch nets
that were saved with torch.save using pickle.load. This is because torch
uses custom persistent_load functions. When trying to load with pickle,
we thus get:

_pickle.UnpicklingError: A load persistent id instruction was encountered, but no persistent_load function was specified.

Therefore, I had to keep torch.load as a fallback to avoid backwards
incompatibility. The bad news is that the initial problem persists,
namely that even when passing torch_load_kwargs, users get the
FutureWarning about weights_only. The good news is that users can just
re-save their net with the new skorch version and from then on they
won't see the warning again.
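
The fallback then roughly looks like this sketch (reusing TorchStorageUnpickler from above; treating UnpicklingError as the signal for an old-style file is an assumption):

import pickle
import torch

def load_pickled_net(file_obj, **torch_load_kwargs):
    try:
        return TorchStorageUnpickler(file_obj, torch_load_kwargs=torch_load_kwargs).load()
    except pickle.UnpicklingError:
        # Old files were written with torch.save and need torch.load (with its
        # persistent_load machinery); this path still emits the FutureWarning.
        file_obj.seek(0)
        return torch.load(file_obj, **torch_load_kwargs)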

Note that I didn't add a specific test for this problem of loading old
nets from before the change, because test_pickle_load, which uses a
checked-in pickle file, already covers this.

Other considered solutions:

1. Why not continue using torch.save/torch.load and just pass the
torch_load_kwargs argument to it? This is unfortunately not that easy.
When switching to weights_only=True, torch will refuse to load any
custom objects, e.g. class MyModule. There is a way to prevent that,
namely via torch.serialization.add_safe_globals (see the sketch after
this list), but it is a ton of work to add all required objects there,
as even builtin Python types are mostly not supported.
2. We cannot use a with torch.device(...) context manager, as it is not
honored during unpickling.
3. During __getstate__, we could recursively go through the state, pop
all torch tensors, and replace them with, say, numpy arrays plus
additional metadata like the device, then use this info to restore
those objects during __setstate__. Even though this looks like a cleaner
solution, it is much more complex and therefore, I'd argue, more error
prone.
4. Don't do anything and just live with the warning: this would work --
until PyTorch switches the default. Therefore, we would have had to
tackle this sooner or later.
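
For reference, a sketch of the add_safe_globals route from point 1 (the API exists since torch 2.4; MyModule is a stand-in), which hints at why allowlisting everything quickly becomes a lot of work:

import torch
from torch import nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(2, 2)

torch.save(MyModule(), 'my_module.pt')

# With weights_only=True, only allowlisted classes may be reconstructed.
# Depending on the torch version, even more entries than these may be needed.
torch.serialization.add_safe_globals([MyModule, nn.Linear])
loaded = torch.load('my_module.pt', weights_only=True)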

Notes

While working on this, I thought that we could most likely remove the
cuda_dependent_attributes_ (which contains the net.module_,
net.optimizer_, etc.). Their purpose was to call torch.load on these
attributes specifically, but with the new Unpickler, it should also work
without this. However, I kept the attribute for now, mainly for these
reasons:

1. I didn't want to change more than necessary, as these changes are
delicate and I don't want to break any existing skorch code or pickle files.
2. The attribute itself is public, so in theory, users may rely on its
existence (not sure if in practice). We would thus have to keep most of
the code related to this attribute.

But LMK if you think we should deprecate and eventually remove this
attribute.