diff --git a/modulus/datapipes/gnn/vortex_shedding_dataset.py b/modulus/datapipes/gnn/vortex_shedding_dataset.py index a70c0670e4..b83b2c63dc 100644 --- a/modulus/datapipes/gnn/vortex_shedding_dataset.py +++ b/modulus/datapipes/gnn/vortex_shedding_dataset.py @@ -207,25 +207,19 @@ def __len__(self): return self.length def _get_edge_stats(self): + edge_mean = 0 + edge_meansqr = 0 + for i in range(self.num_samples): + edge_mean += torch.mean(self.graphs[i].edata["x"], dim=0) + edge_meansqr += torch.mean(torch.square(self.graphs[i].edata["x"]), dim=0) + edge_mean /= self.num_samples + edge_meansqr /= self.num_samples + edge_std = torch.sqrt(edge_meansqr - torch.square(edge_mean)) stats = { - "edge_mean": 0, - "edge_meansqr": 0, + "edge_mean": edge_mean, + "edge_std": edge_std, } - for i in range(self.num_samples): - stats["edge_mean"] += ( - torch.mean(self.graphs[i].edata["x"], dim=0) / self.num_samples - ) - stats["edge_meansqr"] += ( - torch.mean(torch.square(self.graphs[i].edata["x"]), dim=0) - / self.num_samples - ) - stats["edge_std"] = torch.sqrt( - stats["edge_meansqr"] - torch.square(stats["edge_mean"]) - ) - stats.pop("edge_meansqr") - - # save to file - save_json(stats, "edge_stats.json") + save_json(stats, 'edge_stats.json') return stats def _get_node_stats(self): @@ -379,13 +373,8 @@ def _push_forward_diff(invar): @staticmethod def _get_rollout_mask(node_type): - mask = torch.logical_or( - torch.eq(node_type, torch.zeros_like(node_type)), - torch.eq( - node_type, - torch.zeros_like(node_type) + 5, - ), - ) + zeros = torch.zeros_like(node_type) + mask = torch.logical_or(torch.eq(node_type, zeros), torch.eq(node_type, zeros + 5)) return mask @staticmethod