Skip to content

Commit

Permalink
Update torch_kitnet.py, added torch
Browse files Browse the repository at this point in the history
  • Loading branch information
swainsubrat authored May 2, 2024
1 parent b60de93 commit e9267d9
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions titli/ids/torch_kitnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def __init__(self, clusters: list, norms_path: str):
super(TorchKitNET, self).__init__()
self.dataset = "PcapDatasetRaw"
self.input_dim = sum([len(c) for c in clusters])
self.raw = True
self.hr = 0.75
self.clusters = clusters
self.rmse = RMSELoss()
Expand All @@ -89,13 +88,12 @@ def __init__(self, clusters: list, norms_path: str):
self.norm_params = pickle.load(f)

def forward(self, x):
x = torch.tensor(x)
x = x.view(-1, self.input_dim)

x_clusters = []
for c in self.clusters:
norm_max = self.norm_params[f"norm_max_{c[0]}"]
norm_min = self.norm_params[f"norm_min_{c[0]}"]
norm_max = torch.tensor(self.norm_params[f"norm_max_{c[0]}"])
norm_min = torch.tensor(self.norm_params[f"norm_min_{c[0]}"])

x_cluster = torch.index_select(x, 1, torch.tensor(c))
x_cluster = (x_cluster - norm_min) / (norm_max - norm_min + 0.0000000000000001)
Expand Down

0 comments on commit e9267d9

Please sign in to comment.