I'm trying to train a BiStride MeshGraphNet on my dataset (very similar to DrivAerNet), but I keep getting errors. It seems to expect the data in the graph to have a very specific structure, unlike MeshGraphNet, which is better written (and trains on my data). The error I'm getting is:
Traceback (most recent call last):
  File "/workspace/.../test_bsms_mgn.py", line 292, in <module>
    batch_loss = trainer.train(graph['graph'])
  File "/workspace/..../test_bsms_mgn.py", line 245, in train
    loss = self.forward(graph)
  File "/workspace/.../test_bsms_mgn.py", line 251, in forward
    pred = self.model(graph.ndata["x"], graph.edata["x"], graph)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1714, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1725, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/modulus/models/meshgraphnet/bsms_mgn.py", line 165, in forward
    x = self.bistride_processor(x, ms_ids, ms_edges, node_pos)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1714, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1725, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/modulus/models/gnn_layers/bsms.py", line 291, in forward
    h = self.down_gmps[i](h, m_gs[i], pos)
IndexError: list index out of range
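A simplified, pure-Python mock of the failing path may help explain the error. It is not the Modulus code. It assumes, based on the traceback, that the model's forward passes ms_ids and ms_edges to the bistride processor (bsms_mgn.py line 165) and that the processor indexes one per-level graph per configured level (bsms.py line 291). The call at test_bsms_mgn.py line 251 passes only the node features, edge features and graph, so if the multi-scale arguments default to empty sequences (an assumption), the per-level list is empty and the first index fails. The function names and num_levels below are hypothetical.

```python
from typing import List, Sequence


def build_level_graphs(ms_edges: Sequence[list], ms_ids: Sequence[list]) -> List[dict]:
    """Hypothetical stand-in for m_gs: pair each level's edges with its node ids."""
    return [{"edges": e, "ids": i} for e, i in zip(ms_edges, ms_ids)]


def mock_bistride_processor(num_levels: int, ms_edges=(), ms_ids=()) -> int:
    """Walk down the configured levels, indexing one level graph per step."""
    m_gs = build_level_graphs(ms_edges, ms_ids)
    visited = 0
    for i in range(num_levels):
        level = m_gs[i]  # IndexError when m_gs has fewer than num_levels entries
        visited += len(level["ids"])
    return visited


# Model configured with two coarse levels, but the caller omitted ms_edges/ms_ids.
try:
    mock_bistride_processor(num_levels=2)
except IndexError as err:
    print("IndexError:", err)  # IndexError: list index out of range

# Supplying one entry per level lets the loop complete.
print(mock_bistride_processor(2, ms_edges=[[(0, 1)], [(0, 1)]], ms_ids=[[0, 2, 4], [0, 4]]))  # 5
```

In other words, the error is less about the graph's internal structure and more about per-sample multi-scale data that never reaches the model.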
You are correct: BSMS MGN expects the data in a specific format. To produce this format, wrap your dataset class in BistrideMultiLayerGraphDataset, as is done in the Ahmed body example. You can do this either in code or through the Hydra config; see the BSMS Ahmed body experiment and the corresponding dataset config.
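A rough sketch of what this format implies for the training step, under stated assumptions: each wrapped sample is assumed to be a dict holding the original "graph" plus multi-scale entries named "ms_edges" and "ms_ids" (names taken from the traceback, not verified against the Modulus source), and the model is assumed to accept them as keyword arguments. The trainer in the traceback passes only graph['graph'] to train() (line 292), so extra entries like these would never reach the model. The helper bsms_forward, the stub model and the stub graph are hypothetical and stand in for Modulus and DGL so the sketch runs on its own.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict


def bsms_forward(model: Callable[..., Any], sample: Dict[str, Any]) -> Any:
    """Forward one wrapped sample, passing the multi-scale data along with the graph.

    Assumes `sample` carries "graph", "ms_edges" and "ms_ids" keys (hypothetical
    layout) and that `model` accepts ms_edges/ms_ids as keyword arguments.
    """
    graph = sample["graph"]
    return model(
        graph.ndata["x"],
        graph.edata["x"],
        graph,
        ms_edges=sample["ms_edges"],
        ms_ids=sample["ms_ids"],
    )


def stub_model(node_x, edge_x, graph, ms_edges=(), ms_ids=()):
    """Stand-in model that reports how many coarse levels it received."""
    if len(ms_edges) != len(ms_ids):
        raise ValueError("ms_edges and ms_ids must have one entry per level")
    return len(ms_ids)


# Stub graph exposing the same ndata/edata attribute names as a DGL graph.
graph = SimpleNamespace(ndata={"x": [[0.0], [1.0]]}, edata={"x": [[0.5]]})
sample = {"graph": graph, "ms_edges": [[(0, 1)]], "ms_ids": [[0]]}

print(bsms_forward(stub_model, sample))  # 1
```

The design point is that the trainer receives the whole sample dict rather than only its "graph" entry, so the same loop can serve models that need extra inputs.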
I'm not sure I understand. Do you mean that, if I want to test both MeshGraphNet and BSMS MGN on the same data, I need to write two different dataset classes? That's not great from a SWE point of view - I'd like my dataset class to be independent of the model class, as much as possible. Of course, complete decoupling is not realistic (if I want to test a set of GNN models, I expect the Dataset class to have a graph building method), but having to write a different class for each model I want to test is definitely suboptimal. Maybe I didn't understand your suggestion?
You don't need to write a new dataset class; all you have to do is wrap your existing dataset class with the BistrideMultiLayerGraphDataset class, as demonstrated in the config I mentioned in my response.
Specifically, in that config example, the existing Ahmed body dataset class, AhmedBodyDataset, is wrapped by BistrideMultiLayerGraphDataset. In your case, all you have to do is provide your own existing class instead of AhmedBodyDataset.
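The decoupling idea behind wrapping can be sketched in plain Python. MyDataset, MultiScaleWrapper, num_layers and toy_coarsen below are hypothetical and only illustrate composition; the real BistrideMultiLayerGraphDataset builds its levels with bi-stride pooling, which this toy rule does not reproduce. The base dataset stays model-agnostic and is used as-is for MeshGraphNet, while the wrapper adds extra per-sample fields only for the BSMS variant.

```python
from typing import Any, Callable, Dict, List, Sequence


class MyDataset:
    """Model-agnostic base dataset: each item is a dict with a "graph" entry."""

    def __init__(self, graphs: Sequence[Any]):
        self.graphs = list(graphs)

    def __len__(self) -> int:
        return len(self.graphs)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return {"graph": self.graphs[idx]}


class MultiScaleWrapper:
    """Hypothetical wrapper: delegates to any dataset and appends multi-scale fields."""

    def __init__(self, dataset, num_layers: int,
                 coarsen_fn: Callable[[Any, int], Dict[str, List]]):
        self.dataset = dataset
        self.num_layers = num_layers
        self.coarsen_fn = coarsen_fn
        self._cache: Dict[int, Dict[str, List]] = {}  # coarsen each sample only once

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.dataset[idx]
        if idx not in self._cache:
            self._cache[idx] = self.coarsen_fn(item["graph"], self.num_layers)
        # Build a new dict so the base dataset's item is left untouched.
        return {**item, **self._cache[idx]}


def toy_coarsen(graph: List[int], num_layers: int) -> Dict[str, List]:
    """Toy stand-in for bi-stride pooling: keep every other node per level."""
    ids, ms_ids = list(range(len(graph))), []
    for _ in range(num_layers):
        ids = ids[::2]
        ms_ids.append(ids)
    return {"ms_ids": ms_ids, "ms_edges": [[] for _ in ms_ids]}


base = MyDataset([[10, 11, 12, 13, 14]])  # one "graph" with five nodes
wrapped = MultiScaleWrapper(base, num_layers=2, coarsen_fn=toy_coarsen)

print(base[0].keys())        # dict_keys(['graph'])  -> MeshGraphNet input
print(wrapped[0]["ms_ids"])  # [[0, 2, 4], [0, 4]]   -> BSMS MGN input
```

Only one dataset class is written; the model-specific preprocessing lives in the wrapper, which is chosen at configuration time.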
If you prefer doing it in code rather than through the Hydra config, the code will look roughly like this:
Version
0.8.0
On which installation method(s) does this occur?
Docker
Describe the issue
Can you help? It would be useful if you provided an example for testing BiStride MeshGraphNet, but the only example mentioned in the documentation concerns the Ahmed body dataset, which is not included in the examples folder: https://docs.nvidia.com/deeplearning/modulus/modulus-core/examples/cfd/aero_graph_net/readme.html#bsms-mgn-training
Minimum reproducible example
This is the dataset class:
And this is the __init__ method of my trainer class:
Relevant log output
No response
Environment details
No response