from contextlib import contextmanager

import functools
import random
import typing

import numpy as np
import torch

class RandomGeneratorState(typing.NamedTuple):
    random: typing.Tuple[typing.Any, ...]
    torch: torch.Tensor
    numpy: typing.Tuple[typing.Any, ...]
    torch_cuda: typing.Optional[typing.List[torch.Tensor]]

def get_random_generator_state(cuda: bool = torch.cuda.is_available()) -> RandomGeneratorState:
    """ Get the `torch`, `numpy` and `random` random generator state.

    Args:
        cuda (bool, optional): If `True`, also saves the CUDA random generator state. Note that
            getting and setting the random generator state for CUDA can be quite slow if you
            have a lot of GPUs.

    Returns:
        RandomGeneratorState
    """
    return RandomGeneratorState(random.getstate(), torch.random.get_rng_state(),
                                np.random.get_state(),
                                torch.cuda.get_rng_state_all() if cuda else None)

def set_random_generator_state(state: RandomGeneratorState):
    """ Set the `torch`, `numpy` and `random` random generator state.

    Args:
        state (RandomGeneratorState)
    """
    random.setstate(state.random)
    torch.random.set_rng_state(state.torch)
    np.random.set_state(state.numpy)
    # Only restore the CUDA state if it was saved and it matches the number of devices currently
    # visible; otherwise restoring could fail or assign states to the wrong devices.
    if (state.torch_cuda is not None and torch.cuda.is_available() and
            len(state.torch_cuda) == torch.cuda.device_count()):  # pragma: no cover
        torch.cuda.set_rng_state_all(state.torch_cuda)

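# A minimal round-trip sketch, added for illustration only (`_example_state_round_trip` is not
# part of the original module): saving the state, drawing, then restoring the state should make
# the next draw repeat exactly.
def _example_state_round_trip():
    state = get_random_generator_state(cuda=False)
    first = torch.randn(3)  # advances the global `torch` generator
    set_random_generator_state(state)  # rewind to the saved state
    second = torch.randn(3)  # repeats `first` exactly
    assert torch.equal(first, second)
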
@contextmanager
def fork_rng(seed=None, cuda=torch.cuda.is_available()):
    """ Fork the `torch`, `numpy` and `random` random generators, so that when you return, the
    random generators are reset to the state they were previously in.

    Args:
        seed (int or None, optional): If defined, this seeds the forked random generators. This
            is a convenience parameter.
        cuda (bool, optional): If `True`, also saves the CUDA random generator state. Getting and
            setting the random generator state can be quite slow if you have a lot of GPUs.
    """
    state = get_random_generator_state(cuda)
    if seed is not None:
        set_seed(seed, cuda)
    try:
        yield
    finally:
        set_random_generator_state(state)

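# A usage sketch, illustrative only: draws inside the `fork_rng` block are seeded and therefore
# reproducible, while the surrounding generator state is untouched once the block exits.
def _example_fork_rng():
    before = torch.random.get_rng_state()
    with fork_rng(seed=123, cuda=False):
        reproducible = torch.randn(3)  # same values on every call, thanks to the seed
    assert torch.equal(before, torch.random.get_rng_state())  # state was restored on exit
    return reproducible
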
def fork_rng_wrap(function=None, **kwargs):
    """ Decorator alias for `fork_rng`.
    """
    if not function:
        return functools.partial(fork_rng_wrap, **kwargs)

    @functools.wraps(function)
    def wrapper(*args, **kwargs_):
        # Forward the wrapped function's arguments unchanged.
        with fork_rng(**kwargs):
            return function(*args, **kwargs_)

    return wrapper

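# Decorator usage sketch, illustrative only: calling `fork_rng_wrap` with keyword arguments
# alone returns a `functools.partial`, so both `@fork_rng_wrap` and the parameterized form
# below work.
@fork_rng_wrap(seed=123, cuda=False)
def _example_deterministic_sample():
    return torch.randn(3)  # identical on every invocation
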
def set_seed(seed, cuda=torch.cuda.is_available()):
    """ Set seed values for random generators.

    Args:
        seed (int): Value used as a seed.
        cuda (bool, optional): If `True`, also sets the CUDA seeds.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if cuda:  # pragma: no cover
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
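
# A usage sketch, illustrative only: seeding all three libraries at once makes mixed
# `random` / `numpy` / `torch` pipelines repeatable end to end.
def _example_set_seed():
    set_seed(123, cuda=False)
    return random.random(), np.random.rand(), torch.rand(1)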