Skip to content

Commit

Permalink
Add changes to run the tensor on CUDA. Made changes in the forward…
Browse files Browse the repository at this point in the history
… function of the class TorchKitNET
  • Loading branch information
raghavs821 committed May 24, 2024
1 parent fd38e61 commit add1e1f
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions titli/ids/torch_kitnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def forward(self, x):

x_clusters = []
for c in self.clusters:
norm_max = torch.tensor(self.norm_params[f"norm_max_{c[0]}"])
norm_min = torch.tensor(self.norm_params[f"norm_min_{c[0]}"])
norm_max = torch.tensor(self.norm_params[f"norm_max_{c[0]}"]).to(x.device)
norm_min = torch.tensor(self.norm_params[f"norm_min_{c[0]}"]).to(x.device)

x_cluster = torch.index_select(x, 1, torch.tensor(c))
x_cluster = torch.index_select(x, 1, torch.tensor(c).to(x.device))
x_cluster = (x_cluster - norm_min) / (norm_max - norm_min + 0.0000000000000001)
x_cluster = x_cluster.float()

Expand All @@ -114,10 +114,11 @@ def forward(self, x):
tails = torch.stack(tail_losses)

# normalize the tails
norm_max = torch.tensor(self.norm_params["norm_max_output"])
norm_min = torch.tensor(self.norm_params["norm_min_output"])
norm_max = torch.tensor(self.norm_params["norm_max_output"]).to(x.device)
norm_min = torch.tensor(self.norm_params["norm_min_output"]).to(x.device)
tails = (tails - norm_min) / (norm_max - norm_min + 0.0000000000000001)
tails = tails.float()
x = self.head(tails)

return x, tails

0 comments on commit add1e1f

Please sign in to comment.