-
Notifications
You must be signed in to change notification settings - Fork 515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Integrate Flex Decoding #196
base: main
Are you sure you want to change the base?
Conversation
model.py
Outdated
@@ -89,7 +103,7 @@ def update(self, input_pos, k_val, v_val): | |||
return k_out, v_out | |||
|
|||
class Transformer(nn.Module): | |||
def __init__(self, config: ModelArgs) -> None: | |||
def __init__(self, config: ModelArgs, get_mask_mod: Callable[[int], _mask_mod_signature]) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get_mask_mod
shouldn't take an integer - it should take a mask_mod
. We also don't need to set at as an argument, just set it as an attribute within the module.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Specifically, you should be able to take any existing mask_mod
and wrap it to make it automatically support an offset.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly looks good (modulo some other nits), although we'll probably want to wait on landing this until the other PRs on pytorch core are landed at least.
This looks interesting. I would like to share some numbers we got with torch.compile + flashinfer in sglang. It can serve as some good baselines. To run the 32k one, you need to edit the
You can find more numbers at sgl-project/sglang#1008 |
@merrymercy We run on nerfed H100s internally at Meta with only 2.4 TB/s of bandwidth, so these numbers aren't 1:1 comparable. But it's a good comparison :) |
logits = model(x, input_pos) | ||
block_index = input_pos // block_mask.BLOCK_SIZE[0] | ||
mask = block_mask[:, :, block_index] | ||
mask.mask_mod = block_mask.mask_mod |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
offline discussed that BlockMask getitem sets mask_mod as None and the user needs to specify the correct mask_mod. In GPT-Fast, we rely on model.get_mask_mod
to do so.
…compile (#134627) Adds a helper function for getting the block mask for a specific row index during decoding. We need this change to avoid the pytree + torch.compile issue #134731. Tested in gpt-fast [pr](pytorch-labs/gpt-fast#196). Pull Request resolved: #134627 Approved by: https://github.com/Chillee
…compile (pytorch#134627) Adds a helper function for getting the block mask for a specific row index during decoding. We need this change to avoid the pytree + torch.compile issue pytorch#134731. Tested in gpt-fast [pr](pytorch-labs/gpt-fast#196). Pull Request resolved: pytorch#134627 Approved by: https://github.com/Chillee
…compile (pytorch#134627) Adds a helper function for getting the block mask for a specific row index during decoding. We need this change to avoid the pytree + torch.compile issue pytorch#134731. Tested in gpt-fast [pr](pytorch-labs/gpt-fast#196). Pull Request resolved: pytorch#134627 Approved by: https://github.com/Chillee
…compile (pytorch#134627) Adds a helper function for getting the block mask for a specific row index during decoding. We need this change to avoid the pytree + torch.compile issue pytorch#134731. Tested in gpt-fast [pr](pytorch-labs/gpt-fast#196). Pull Request resolved: pytorch#134627 Approved by: https://github.com/Chillee
return sample(logits, **sampling_kwargs) | ||
|
||
def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): | ||
block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device="cuda") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
try doing
create_block_mask_compile = torch.compile(create_block_mask)
as a global
This PR integrates flex decoding with gpt-fast.
End-to-end performance gain of Llama2-7b
Device: H100
Unit: tokens/sec
command:
Please also set
ModelArgs.block_size = 65536
to repeat the result.We expect to see larger speedup on longer context length.