# lstm.py
import logging

# Silence tensorflow
import os

import joblib
import numpy as np
import orjson

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
logging.getLogger("tensorflow").setLevel(logging.FATAL)

from tensorflow.keras.layers import LSTM  # noqa: E402
from tensorflow.keras.layers import Dense  # noqa: E402
from tensorflow.keras.models import Sequential  # noqa: E402
from tensorflow.keras.optimizers import Adam  # noqa: E402

import ipal_iids.settings as settings  # noqa: E402
from combiner.combiner import Combiner  # noqa: E402


class LSTMCombiner(Combiner):
    _name = "LSTM"
    _description = "Learns a time-aware LSTM combiner."
    _requires_training = True

    _lstm_default_settings = {
        "epochs": 20,
        # Overall, the combiner looks back lookback * stride data points
        "lookback": 30,
        "stride": 1,
        "use_scores": False,
        "verbose": 0,
    }

    def __init__(self):
        super().__init__()
        self._add_default_settings(self._lstm_default_settings)

        self.model = None
        self.keys = None
        self.buffer = []
        self.window_size = self.settings["lookback"] * self.settings["stride"]

    def _get_activations(self, alerts, scores):
        data = scores if self.settings["use_scores"] else alerts

        if set(data.keys()) != set(self.keys):
            settings.logger.error("Keys of combiner do not match data")
            settings.logger.error(f"- data keys: {','.join(data.keys())}")
            settings.logger.error(f"- combiner keys: {','.join(self.keys)}")
            exit(1)

        return [float(data[ids]) for ids in self.keys]

    def _lstm_model(self, input_dim):
        # One LSTM layer over windows of shape (lookback, input_dim), followed
        # by a single sigmoid unit yielding a combined alert score in [0, 1]
        model = Sequential()
        model.add(LSTM(input_dim, input_shape=(self.settings["lookback"], input_dim)))
        model.add(Dense(1, activation="sigmoid"))
        model.compile(loss="binary_crossentropy", optimizer=Adam(), metrics=["acc"])

        model.summary(print_fn=settings.logger.info)

        return model

    def train(self, file):
        buffer = []
        seq = []
        annotations = []

        settings.logger.info("Loading combiner training file")

        with self._open_file(file, "r") as f:
            for line in f:
                js = orjson.loads(line)

                if self.keys is None:
                    self.keys = sorted(js["scores"].keys())

                # Manage buffer
                buffer.append(self._get_activations(js["alerts"], js["scores"]))
                if len(buffer) < self.window_size:
                    continue
                elif len(buffer) > self.window_size:
                    buffer.pop(0)

                # Add training sequence: every stride-th activation vector,
                # newest first, yielding 'lookback' time steps per sample
                seq.append(np.array(buffer[:: -self.settings["stride"]]))
                annotations.append(js["malicious"] is not False)

        self.model = self._lstm_model(len(self.keys))

        settings.logger.info(f"Training LSTM for {self.settings['epochs']} epochs...")
        self.model.fit(
            np.array(seq),
            np.array(annotations),
            epochs=self.settings["epochs"],
            verbose=self.settings["verbose"],
        )
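
    # A hypothetical example of the newline-delimited JSON the training file
    # is expected to contain (values are illustrative placeholders; only the
    # "alerts", "scores", and "malicious" fields are read in train() above):
    #
    #   {"alerts": {"ids-a": false, "ids-b": true},
    #    "scores": {"ids-a": 0.1, "ids-b": 0.8},
    #    "malicious": "attack-1"}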

    def combine(self, alerts, scores):
        # Manage buffer
        self.buffer.append(self._get_activations(alerts, scores))
        if len(self.buffer) < self.window_size:
            return False, 0, 0
        elif len(self.buffer) > self.window_size:
            self.buffer.pop(0)

        sequence = self.buffer[:: -self.settings["stride"]]
        prediction = float(
            self.model.predict(np.array([sequence]), verbose=False)[0][0]
        )

        return prediction > 0.5, prediction, 0

    def save_trained_model(self):
        if self.settings["model-file"] is None:
            return False

        model = {
            "_name": self._name,
            "settings": self.settings,
            "model": self.model,
            "keys": self.keys,
            "window_size": self.window_size,
        }

        joblib.dump(model, self._resolve_model_file_path(), compress=3)

        return True

    def load_trained_model(self):
        if self.settings["model-file"] is None:
            return False

        try:  # Open model file
            model = joblib.load(self._resolve_model_file_path())
        except FileNotFoundError:
            settings.logger.info(
                f"Model file {str(self._resolve_model_file_path())} not found."
            )
            return False

        # Load model
        assert self._name == model["_name"]
        self.settings = model["settings"]
        self.model = model["model"]
        self.keys = model["keys"]
        self.window_size = model["window_size"]
        self.buffer = []

        return True
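

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a
# previously saved model at the hypothetical path "model.joblib" and two
# illustrative IDS names, "ids-a" and "ids-b"; all three are placeholders,
# not values defined by this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    combiner = LSTMCombiner()
    combiner.settings["model-file"] = "model.joblib"  # hypothetical path

    if combiner.load_trained_model():
        # combine() returns (False, 0, 0) until lookback * stride activation
        # vectors have been buffered; afterwards it yields the LSTM prediction
        alerts = {"ids-a": False, "ids-b": True}
        scores = {"ids-a": 0.1, "ids-b": 0.8}
        malicious, score, _ = combiner.combine(alerts, scores)
        print(malicious, score)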