Hi authors,
I'm using Mamba for my 3D segmentation model. The Mamba block receives input with shape (B, L, D), where L is relatively small (e.g., 30) and D is large (e.g., 1600). When I compute the number of trainable parameters with `sum(p.numel() for p in model.parameters() if p.requires_grad)`, I get an unexpectedly large count (300M to 1B). However, despite this high parameter count, I can train the model with a batch size of up to 12 on a single NVIDIA 3090 (24 GB VRAM) without memory issues, and training is also quite fast. This seems unusual.
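For reference, here is a minimal sketch of how I count the parameters, plus a per-module breakdown I use for debugging (the helper names here are mine, just for illustration):

```python
import torch.nn as nn

def count_trainable_parameters(model: nn.Module) -> int:
    # Total number of trainable parameters in the model.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def count_parameters_by_module(model: nn.Module) -> dict:
    # Trainable parameters per top-level child module, to see
    # which submodule dominates the overall count.
    return {
        name: sum(p.numel() for p in child.parameters() if p.requires_grad)
        for name, child in model.named_children()
    }
```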
I have double-checked my model implementation and confirmed that the code is correct. Do you have any insights into what might be causing this?
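As a sanity check, a single standalone Mamba block with the same model dimension can also be counted in isolation (this assumes the `mamba_ssm` package's `Mamba` class with its default `d_state`, `d_conv`, and `expand` settings, which may differ from my actual configuration):

```python
from mamba_ssm import Mamba

# One block with d_model = 1600, as in my model; the remaining
# hyperparameters are left at the library defaults (assumption).
block = Mamba(d_model=1600)
num_params = sum(p.numel() for p in block.parameters() if p.requires_grad)
print(f"single Mamba block: {num_params:,} trainable parameters")
```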
Thank you for your support!