-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnpzload.py
66 lines (57 loc) · 2.03 KB
/
npzload.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import numpy as np
import downcast as _downcast
class NPZLoad:
    """
    Superclass for adding automatic serialization to/from npz files. Only
    scalar/array instance variables are considered.

    Instance methods
    ----------------
    save : save the object to a file.

    Class methods
    -------------
    load : read an instance from a file.
    """

    def save(self, filename, compress=False, downcast=None):
        """
        Save the object to file as a `.npz` archive.

        Only instance attributes (those in ``vars(self)``) are written, and
        only when they are numpy arrays or scalars according to
        ``np.isscalar``. Attributes whose name also appears in
        ``dir(type(self))`` or starts with ``'__'`` are skipped.

        Parameters
        ----------
        filename : str
            The file path to save to.
        compress : bool
            If True, compress the npz archive (slow). Default False.
        downcast : numpy data type or tuple of numpy data types, optional
            One or more "short" data types. Arrays (but not scalars) with a
            data type compatible with one of them, but larger, are cast to
            the short type. Applies also to structured arrays.
        """
        # A set makes the per-attribute membership test O(1).
        classdir = set(dir(type(self)))
        variables = {
            n: x
            for n, x in vars(self).items()
            if n not in classdir
            and not n.startswith('__')
            and (np.isscalar(x) or isinstance(x, np.ndarray))
        }
        if downcast is not None:
            if not isinstance(downcast, tuple):
                downcast = (downcast,)
            # Only ndarrays are downcast, as documented above: numpy scalars
            # also carry a ``dtype`` attribute but must keep their precision.
            for n, x in variables.items():
                if isinstance(x, np.ndarray):
                    variables[n] = np.asarray(
                        x, _downcast.downcast(x.dtype, *downcast)
                    )
        # NOTE: every attribute is passed to np.savez as a keyword argument,
        # so an attribute named like one of np.savez's own parameters
        # (e.g. ``file``) collides with it and cannot be saved.
        fun = np.savez_compressed if compress else np.savez
        fun(filename, **variables)

    @classmethod
    def load(cls, filename):
        """
        Return an instance loading the object from a file which was written by
        `save`.

        The instance is created with ``cls.__new__``, so ``__init__`` is not
        called. Every array in the archive becomes an attribute of the same
        name; 0-d arrays (saved scalars) are converted back to Python scalars
        with ``.item()``.

        Parameters
        ----------
        filename : str
            Path of the `.npz` archive to read.

        Returns
        -------
        obj : cls
            The reconstructed instance.
        """
        obj = cls.__new__(cls)
        # The context manager closes the archive even if reading fails.
        with np.load(filename) as arch:
            for n, x in arch.items():
                if x.shape == ():
                    x = x.item()
                setattr(obj, n, x)
        return obj