FIX Unpickling without using torch.load (#1092)
Resolves #1090.

PyTorch plans to make the switch to weights_only=True for torch.load. We
already partly dealt with that in #1064 when it comes to
save_params/load_params. However, we still had a gap. Namely, when using
pickle directly, i.e. when going through __getstate__ and __setstate__,
we are still using torch.load and torch.save without handling
weights_only. This will cause trouble in the future when the default is
switched. But it's also annoying right now, because users will get the
FutureWarning about weights_only, even if they correctly pass
torch_load_kwargs (see #1090).

The reason why we use torch.save/torch.load for pickling is that those
functions are basically _extended_ pickle functions with the benefit of
supporting the map_location argument to handle the device of torch
tensors, something plain pickle does not offer. The map_location
argument is important: when saving a net that uses CUDA and loading it
on a machine without CUDA, we would otherwise run into an error.
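
For illustration, here is a minimal sketch of what map_location buys us
(file name and devices are hypothetical):

    import torch

    # Saved on a CUDA machine; plain pickle has no way to remap the device.
    torch.save(torch.ones(3, device="cuda"), "tensor.pt")

    # On a CPU-only machine, map_location remaps the CUDA storages at load time:
    loaded = torch.load("tensor.pt", map_location=torch.device("cpu"))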

However, with the move to weights_only=True, these torch.save/torch.load
will become _reduced_ pickle functions, as they will only support a
small subset of objects by default. Therefore, we wouldn't be able to
rely on torch.save/torch.load for pickling the whole skorch object.

In this PR, we thus move to using plain pickle for this. However, now we
run into the issue of how to handle the map_location. The solution I
ended up with is to intercept torch's _load_from_bytes using a custom
Unpickler, and to use torch.load specifically there. That way, we can
pass the map_location and the other torch_load_kwargs. The rest of the
unpickling process works as normal.

Yes, this is a private function, so we cannot be sure that it will work
indefinitely. If there is a better suggestion, I'm open to it. However,
the function has existed for 7 years, so it's not very likely that it
will change anytime soon:

https://github.com/pytorch/pytorch/blame/0674ab7e33c3f627ca6781ce98468ec1dd4743a5/torch/storage.py#L525
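
For reference, the intercepted helper is essentially the following
(paraphrased from the revision linked above, not copied verbatim):

    import io
    import torch

    # torch/storage.py (simplified): tensor storages round-trip through this
    # module-level function during unpickling.
    def _load_from_bytes(b):
        return torch.load(io.BytesIO(b))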

A drawback of this solution is that old skorch nets that were saved with
torch.save cannot simply be loaded with pickle.load, because torch uses
custom persistent_load functions. When trying to load such a file with
plain pickle, we thus get:

_pickle.UnpicklingError: A load persistent id instruction was encountered, but no persistent_load function was specified.

Therefore, I had to keep torch.load as a fallback to avoid breaking
backwards compatibility. The bad news is that for such old pickle files,
the initial problem persists: even when passing torch_load_kwargs, users
get the FutureWarning about weights_only. The good news is that users
can simply re-save their net with the new skorch version, and from then
on they won't see the warning again.
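
Concretely, the re-save workaround looks like this (file names are
hypothetical):

    import pickle

    # Loading a net pickled with an older skorch version; this may emit the
    # FutureWarning one last time, due to the torch.load fallback.
    with open("old_net.pkl", "rb") as f:
        net = pickle.load(f)

    # Re-save with the new skorch version; subsequent loads go through the
    # custom Unpickler and are warning-free.
    with open("net.pkl", "wb") as f:
        pickle.dump(net, f)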

Note that I didn't add a specific test for loading nets saved before
this change, because test_pickle_load, which uses a checked-in pickle
file, already covers this.

Other considered solutions:

1. Why not continue using torch.save/torch.load and just pass the
torch_load_kwargs argument to it? This is unfortunately not that easy.
When switching to weights_only=True, torch will refuse to load any
custom objects, e.g. class MyModule. There is a way to prevent that,
namely via torch.serialization.add_safe_globals, but it is a ton of work
to add all required objects there, as even built-in Python types are
mostly not supported (see the sketch after this list).
2. We cannot use with torch.device(...), as this context manager is not
honored during unpickling.
3. During __getstate__, we could recursively go through the state, pop
all torch tensors, and replace them with, say, numpy arrays plus
additional metadata like the device, then use this info to restore those
objects during __setstate__. Even though this looks like a cleaner
solution, it is much more complex and therefore, I'd argue, more error
prone.
4. Don't do anything and just live with the warning: This will work --
until PyTorch switches the default. Therefore, we have to tackle this
sooner or later.
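
To illustrate option 1, here is a sketch of the allow-listing that
weights_only=True would require (the module class is made up):

    import torch
    import torch.nn as nn

    class MyModule(nn.Module):  # stand-in for any user-defined module
        pass

    # Every custom class reachable from the pickled net must be allow-listed:
    torch.serialization.add_safe_globals([MyModule])
    # ... and likewise for optimizers, criteria, history entries, and so on;
    # even many builtin Python types are not allowed by default.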

Notes

While working on this, I thought that we could most likely remove
cuda_dependent_attributes_ (which lists net.module_, net.optimizer_,
etc.). Its purpose was to call torch.load on these attributes
specifically, but with the new Unpickler, things should also work
without it. However, I kept the attribute for now, mainly for these
reasons:

1. I didn't want to change more than necessary, as these changes are
delicate and I don't want to break any existing skorch code or pickle files.
2. The attribute itself is public, so in theory, users may rely on its
existence (not sure if they do in practice). We would thus have to keep
most of the code related to this attribute anyway.

But LMK if you think we should deprecate and eventually remove this
attribute.
BenjaminBossan authored Jan 27, 2025
1 parent bb1bac4 commit be93b77
Showing 4 changed files with 90 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
@@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 ## [Unreleased]
 ### Added
 ### Changed
+
+- Loading of skorch nets using pickle: When unpickling a skorch net, you may come across a PyTorch warning that goes: "FutureWarning: You are using torch.load with weights_only=False [...]"; to avoid this warning, pickle the net again and use the new pickle file (#1092)
+
 ### Fixed
 
 ## [1.1.0]
21 changes: 19 additions & 2 deletions skorch/net.py
@@ -13,6 +13,7 @@
 from collections import OrderedDict
 from contextlib import contextmanager
 import os
+import pickle
 import tempfile
 import warnings
 
@@ -33,6 +34,7 @@
 from skorch.exceptions import SkorchTrainingImpossibleError
 from skorch.history import History
 from skorch.setter import optimizer_setter
+from skorch.utils import _TorchLoadUnpickler
 from skorch.utils import _identity
 from skorch.utils import _infer_predict_nonlinearity
 from skorch.utils import FirstStepAccumulator
@@ -2242,7 +2244,7 @@ def __getstate__(self):
                 state.pop(k)
 
         with tempfile.SpooledTemporaryFile() as f:
-            torch.save(cuda_attrs, f)
+            pickle.dump(cuda_attrs, f)
             f.seek(0)
             state['__cuda_dependent_attributes__'] = f.read()
 
@@ -2254,11 +2256,26 @@ def __setstate__(self, state):
         map_location = get_map_location(state['device'])
         load_kwargs = {'map_location': map_location}
         state['device'] = self._check_device(state['device'], map_location)
+        torch_load_kwargs = state.get('torch_load_kwargs') or get_default_torch_load_kwargs()
 
         with tempfile.SpooledTemporaryFile() as f:
+            unpickler = _TorchLoadUnpickler(
+                f,
+                map_location=map_location,
+                torch_load_kwargs=torch_load_kwargs,
+            )
             f.write(state['__cuda_dependent_attributes__'])
             f.seek(0)
-            cuda_attrs = torch.load(f, **load_kwargs)
+            try:
+                cuda_attrs = unpickler.load()
+            except pickle.UnpicklingError:
+                # This object was saved using skorch from before switching to the
+                # custom unpickler, i.e. with torch.save. Fall back to the old loading
+                # code using torch.load. Unfortunately, this means that the user may
+                # get the FutureWarning about weights_only=False. They need to re-save
+                # the net to get rid of the warning.
+                f.seek(0)
+                cuda_attrs = torch.load(f, **load_kwargs)
 
         state.update(cuda_attrs)
         state.pop('__cuda_dependent_attributes__')
35 changes: 35 additions & 0 deletions skorch/tests/test_net.py
@@ -3081,6 +3081,7 @@ def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
         # See discussion in 1063.
         from skorch._version import Version
 
+        # TODO remove once torch 2.5.0 is no longer supported
         if Version(torch.__version__) >= Version('2.6.0'):
             pytest.skip("Test only for torch < v2.6.0")
 
@@ -3097,6 +3098,40 @@ def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
         del call_kwargs['map_location']  # we're not interested in that
         assert call_kwargs == expected_kwargs
 
+    def test_torch_load_kwargs_forwarded_to_torch_load_unpickle(
+            self, net_cls, module_cls, monkeypatch, tmp_path
+    ):
+        # See discussion in 1090
+        # Here we check that custom set torch load args are forwarded to
+        # torch.load even when using pickle. This is the same test otherwise as
+        # test_torch_load_kwargs_forwarded_to_torch_load
+        expected_kwargs = {'weights_only': 123, 'foo': 'bar'}
+        net = net_cls(module_cls, torch_load_kwargs=expected_kwargs).initialize()
+
+        original_torch_load = torch.load
+        # call original torch.load without extra params to prevent error:
+        mock_torch_load = Mock(
+            side_effect=lambda *args, **kwargs: original_torch_load(*args)
+        )
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+        dumped = pickle.dumps(net)
+        pickle.loads(dumped)
+
+        call_kwargs = mock_torch_load.call_args_list[0].kwargs
+        del call_kwargs['map_location']  # we're not interested in that
+        assert call_kwargs == expected_kwargs
+
+    def test_unpickle_no_pytorch_warning(self, net_cls, module_cls, recwarn):
+        # See discussion 1090
+        # When using pickle, i.e. when going through __setstate__, we don't want to get
+        # any warnings about the usage of weights_only.
+        net = net_cls(module_cls).initialize()
+        dumped = pickle.dumps(net)
+        pickle.loads(dumped)
+
+        msg_content = "weights_only"
+        assert not any(msg_content in str(w.message) for w in recwarn.list)
+
     def test_custom_module_params_passed_to_optimizer(
             self, net_custom_module_cls, module_cls):
         # custom module parameters should automatically be passed to the optimizer
34 changes: 33 additions & 1 deletion skorch/utils.py
@@ -11,11 +11,11 @@
 import io
 from itertools import tee
 import pathlib
+import pickle
 import warnings
 
 import numpy as np
 from scipy import sparse
-import sklearn
 from sklearn.exceptions import NotFittedError
 from sklearn.utils import _safe_indexing as safe_indexing
 from sklearn.utils.validation import check_is_fitted as sk_check_is_fitted
@@ -784,3 +784,35 @@ def get_default_torch_load_kwargs():
     if version_torch >= version_default_switch:
         return {"weights_only": True}
     return {"weights_only": False}
+
+
+class _TorchLoadUnpickler(pickle.Unpickler):
+    """Subclass of pickle.Unpickler that intercepts 'torch.storage._load_from_bytes'
+    calls and routes them through torch.load(..., map_location=..., **torch_load_kwargs).
+
+    This way, we can use normal pickle when unpickling a skorch net but still benefit
+    from torch.load to handle the map_location. Note that `with torch.device(...)` does
+    not work for unpickling.
+
+    """
+
+    def __init__(self, *args, map_location, torch_load_kwargs, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.map_location = map_location
+        self.torch_load_kwargs = torch_load_kwargs
+
+    def find_class(self, module, name):
+        # The actual serialized data for PyTorch tensors references
+        # torch.storage._load_from_bytes internally. We intercept that call:
+        if (module == 'torch.storage') and (name == '_load_from_bytes'):
+            # Return a function that uses torch.load with our desired map_location
+            def _load_from_bytes(b):
+                return torch.load(
+                    io.BytesIO(b),
+                    map_location=self.map_location,
+                    **self.torch_load_kwargs
+                )
+            return _load_from_bytes
+
+        return super().find_class(module, name)
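
A usage sketch for the class above, mirroring what __setstate__ does (the
payload here is a stand-in for state['__cuda_dependent_attributes__'], and
the torch_load_kwargs value is just an example):

    import io
    import pickle
    import torch

    payload = pickle.dumps({"weight": torch.zeros(2)})  # hypothetical payload
    unpickler = _TorchLoadUnpickler(
        io.BytesIO(payload),
        map_location=torch.device("cpu"),
        torch_load_kwargs={"weights_only": True},
    )
    restored = unpickler.load()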
