Mamba2 Causality #700

Open
William-HYWu opened this issue Mar 5, 2025 · 6 comments

@William-HYWu

Hi. Thank you for your wonderful work! I would like to ask about the causality of Mamba2. I think it should theoretically be causal; however, when I run the code below:

import torch
from mamba_ssm import Mamba2

torch.random.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=768,   # Model dimension d_model
    d_state=256,   # SSM state expansion factor
    rmsnorm=True,
    d_conv=4,      # Local convolution width
    expand=2,      # Block expansion factor
)

model = model.cuda()
model.eval()
inputs = torch.randn(1, 128, 768).to(torch.device('cuda'))
outputs1 = model(inputs[:,:10,:])
outputs1 = outputs1.squeeze()
outputs2 = model(inputs)
outputs2 = outputs2.squeeze()[:10,:]
print(outputs1.shape)
print(outputs2.shape)
assert torch.equal(outputs1, outputs2), "Outputs are not equal"

I get AssertionError: Outputs are not equal

I have already ruled out randomness as a factor, since when I run

import torch
from mamba_ssm import Mamba2

torch.random.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=768,   # Model dimension d_model
    d_state=256,   # SSM state expansion factor
    rmsnorm=True,
    d_conv=4,      # Local convolution width
    expand=2,      # Block expansion factor
)

model = model.cuda()
model.eval()
inputs = torch.randn(1, 128, 768).to(torch.device('cuda'))
outputs1 = model(inputs[:,:10,:])
outputs1 = outputs1.squeeze()
outputs2 = model(inputs[:,:10,:])
outputs2 = outputs2.squeeze()
print(outputs1.shape)
print(outputs2.shape)
assert torch.equal(outputs1, outputs2), "Outputs are not equal"

the assertion passes.
Is there any extra argument I need to add to make it causal?
Thank you for your help.

@tridao
Collaborator

tridao commented Mar 5, 2025

How large is the difference?

@William-HYWu
Author

How large is the difference?

About 1e-7 to 1e-6. I suspect it is due to floating-point precision rather than the model itself.
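
For reference, a minimal sketch (not part of the original post) of how the size of the discrepancy can be measured, reusing outputs1 and outputs2 from the first snippet:

# Hypothetical check of the maximum absolute difference between the two runs
diff = (outputs1 - outputs2).abs()
print(diff.max())   # expected to be on the order of 1e-7 to 1e-6 in float32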

@peterbjorgensen

Why would these be the same? The hidden states should be different after processing 10 items in the sequence. It is not a linear time-invariant system.

@William-HYWu
Author

Why would these be the same? The hidden states should be different after processing 10 items in the sequence. It is not a linear time-invariant system.

Yes, but I'm comparing the model's output rather than its hidden state. Since the model is causal, the first 10 outputs should remain the same regardless of sequence length, as no future information is used.

That said, the difference I observed is extremely small, which makes me inclined to believe the assertion failed due to inherent GPU precision variations rather than a fundamental issue with the model.

@peterbjorgensen

Yes, sorry, I misread your code. That's interesting. Have you tried setting the chunk_size parameter to 1. Mamba splits the input into chunks and process them in parallel, then recombines them, so there might be some numerical noise depending on the chunking.
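
A minimal sketch of that experiment (assuming Mamba2 exposes a chunk_size constructor argument, as in the reference implementation, and reusing the inputs tensor from the snippets above):

# Hypothetical variant with one token per chunk, to isolate chunking-related numerical noise
model = Mamba2(
    d_model=768,
    d_state=256,
    rmsnorm=True,
    d_conv=4,
    expand=2,
    chunk_size=1,  # default is much larger; chunk_size=1 is slow but removes cross-chunk recombination
).cuda().eval()

outputs1 = model(inputs[:, :10, :]).squeeze()
outputs2 = model(inputs).squeeze()[:10, :]
print((outputs1 - outputs2).abs().max())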

@tridao
Collaborator

tridao commented Mar 5, 2025

How large is the difference?

About 1e-7 to 1e-6. I suspect it is due to floating-point precision rather than the model itself.

That's probably fine.
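
In other words, the strict torch.equal check in the original snippet is too tight for GPU floating-point arithmetic; a tolerance-based comparison is the usual way to express this. A sketch (not from the original thread), with tolerances chosen as an assumption:

# Hypothetical replacement for the strict assertion in the original snippet
assert torch.allclose(outputs1, outputs2, rtol=1e-4, atol=1e-5), \
    "Outputs differ by more than floating-point noise"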
