-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnpzload.py
77 lines (64 loc) · 2.37 KB
/
npzload.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (C) 2022 Giacomo Petrillo
# Released under the MIT license
import numpy as np
import updowncast
class NPZLoad:
"""
Superclass for adding automatic serialization to/from npz files.
Only scalar/array instance variables are serialized. Class attributes and
dunders are ignored.
Instance methods
----------------
save : save the object to a file.
Class methods
-------------
load : read an instance form a file.
Attributes
----------
_npzload_unpack_scalars : bool
If True, 0d arrays are converted to scalars when loading. Default False.
"""
def save(self, filename, compress=False, downcast=None):
"""
Save the object to file as a `.npz` archive.
Parameters
----------
filename : str
The file path to save to.
compress : bool
If True, compress the npz archive (slow). Default False.
downcast : numpy data type or tuple of numpy data types, optional
A list of "short" data types. Arrays (but not scalars) with a
data type compatible with one in the list, but larger, are
casted to the short type. Applies also to structured arrays.
"""
classvars = vars(type(self))
variables = {
n: x
for n, x in vars(self).items()
if n not in classvars
and not n.startswith('__')
and (np.isscalar(x) or isinstance(x, np.ndarray))
}
if downcast is not None:
if not isinstance(downcast, tuple):
downcast = (downcast,)
for n, x in variables.items():
if hasattr(x, 'dtype'):
variables[n] = np.asarray(x, updowncast.downcast(x.dtype, *downcast))
fun = np.savez_compressed if compress else np.savez
fun(filename, **variables)
_npzload_unpack_scalars = False
@classmethod
def load(cls, filename):
"""
Return an instance loading the object from a file which was written by
`save`.
"""
self = cls.__new__(cls)
with np.load(filename) as arch:
for n, x in arch.items():
if x.shape == () and self._npzload_unpack_scalars:
x = x.item()
setattr(self, n, x)
return self