# Downloaded from the official PyTorch repository:
# https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/_stateless.py
import contextlib
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

import torch
from torch import Tensor


# We avoid typing the module argument here because module attributes are declared
# as Union[Parameter, Tensor] by default and using other types causes mypy errors
def _change_class(module) -> None:
    cls = module.__class__
    func_params: Dict[str, Tensor] = module._functional_parameters

    def _getattribute(self, name: str) -> Any:
        # Serve the replacement tensor if one was provided for this name,
        # otherwise fall back to the original class lookup.
        if name in func_params:
            return func_params[name]
        return cls.__getattribute__(self, name)

    # Build a throwaway subclass whose __getattribute__ intercepts lookups,
    # swap the instance onto it, and remember the original class for restore.
    param_cls = type(
        f"StatelessReplacer{cls.__name__}",
        (cls,),
        {
            "__getattribute__": _getattribute,
        },
    )

    module.__class__ = param_cls
    module._orig_class = cls
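
# Illustrative only (not part of the upstream file): after _change_class,
# attribute lookups on the instance go through the generated subclass first,
# e.g.
#
#     lin = torch.nn.Linear(2, 2)
#     lin._functional_parameters = {"weight": torch.zeros(2, 2)}
#     _change_class(lin)
#     type(lin).__name__   # "StatelessReplacerLinear"
#     lin.weight           # the zero tensor, not the registered Parameter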


def _swap_parameters(module, tensor_name: str, tensor: Tensor) -> None:
    # Changes the module class to get a new __getattribute__ dunder method
    # that looks for the reparametrized tensor
    if hasattr(module, "_functional_parameters"):
        module._functional_parameters[tensor_name] = tensor
    else:
        # First swap on this module: create the store and change the class once.
        module._functional_parameters = {}
        module._functional_parameters[tensor_name] = tensor
        _change_class(module)


def _remove_swap(module, name: str) -> None:
    # Restore the original class; a single restore undoes all swaps on this
    # module, so subsequent calls for its other names are no-ops.
    if hasattr(module, "_orig_class"):
        module.__class__ = module._orig_class
        delattr(module, "_orig_class")
        delattr(module, "_functional_parameters")


@contextlib.contextmanager
def reparametrize_module(
    module: torch.nn.Module,
    parameters_and_buffers: Dict[str, Tensor],
) -> Iterator[None]:
    for name, tensor in parameters_and_buffers.items():
        _apply_func_submodules(
            _swap_parameters,
            module, name.split("."), (tensor,))
    yield
    for name in parameters_and_buffers:
        _apply_func_submodules(
            _remove_swap,
            module, name.split("."), ())
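
# Illustrative only (not part of the upstream file): swapping a parameter for
# the duration of a block with the context manager above.
#
#     lin = torch.nn.Linear(2, 2)
#     with reparametrize_module(lin, {"weight": torch.zeros(2, 2)}):
#         lin(torch.ones(1, 2))  # forward runs with the zero weight
#     lin.weight                 # original Parameter, restored on exit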


def _apply_func_submodules(
    func: Callable[..., None],
    module: torch.nn.Module,
    path: List[str],
    args: Tuple,
):
    # Walk a dotted name such as "encoder.layer.weight" down to the owning
    # submodule, then apply func(submodule, "weight", *args).
    if len(path) == 1:
        func(module, path[0], *args)
    else:
        _apply_func_submodules(func, getattr(module, path[0]), path[1:], args)


def functional_call(
    module: torch.nn.Module,
    parameters_and_buffers: Dict[str, Tensor],
    args: Tuple,
    kwargs: Optional[Dict[str, Any]] = None,
):
    r"""Performs a functional call on the module by replacing the module parameters
    and buffers with the provided ones.

    .. note:: If the module has active parametrizations, passing a value in the
        ``parameters_and_buffers`` argument with the name set to the regular parameter
        name will completely disable the parametrization.
        If you want to apply the parametrization function to the value passed,
        please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.

    Args:
        module (torch.nn.Module): the module to call
        parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
            the module call.
        args (tuple): arguments to be passed to the module call
        kwargs (dict): keyword arguments to be passed to the module call

    Returns:
        Any: the result of calling ``module``.
    """
    # TODO allow kwargs such as unsafe and others for parametrization
    if (
        torch.jit.is_tracing()
        or torch.jit.is_scripting()
        or isinstance(module, (
            torch.jit.RecursiveScriptModule,
            torch.jit.ScriptModule,
            torch.jit.ScriptFunction)
        )
    ):
        raise RuntimeError("The stateless API can't be used with Jitted modules")
    if kwargs is None:
        kwargs = {}
    with reparametrize_module(module, parameters_and_buffers):
        if isinstance(args, tuple):
            out = module(*args, **kwargs)
        else:
            out = module(args, **kwargs)
    return out
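

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream file; the module and tensor
# shapes below are illustrative assumptions): calls a Linear layer with
# externally supplied weights, leaving the module's own parameters untouched.
if __name__ == "__main__":
    net = torch.nn.Linear(3, 2)
    new_weight = torch.randn(2, 3)
    new_bias = torch.zeros(2)
    x = torch.randn(4, 3)

    out = functional_call(net, {"weight": new_weight, "bias": new_bias}, (x,))
    print(out.shape)                 # torch.Size([4, 2])
    print(net.weight is new_weight)  # False: the original Parameter was restored

    # The lower-level context manager gives the same result when used directly.
    with reparametrize_module(net, {"weight": new_weight, "bias": new_bias}):
        out2 = net(x)
    assert torch.equal(out, out2)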