Integrate Flex Decoding #196

Open · wants to merge 11 commits into main
Conversation

BoyuanFeng

This PR integrates flex decoding with gpt-fast.

End-to-end performance gain of Llama2-7b

Device: H100
Unit: tokens/sec

| Length | SDPA   | Flex Decoding | Speedup |
|-------:|-------:|--------------:|--------:|
| 1024   | 143.57 | 142.59        | 0.99x   |
| 2048   | 138.67 | 140.13        | 1.01x   |
| 4096   | 128.85 | 135.31        | 1.05x   |
| 8192   | 111.59 | 125.38        | 1.12x   |
| 16384  | 89     | 109.8         | 1.23x   |
| 32768  | 64.11  | 87.67         | 1.37x   |

Command:

```bash
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python generate.py --compile --max_new_tokens 16384 --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is"
```

Please also set ModelArgs.block_size = 65536 to reproduce these results.

We expect larger speedups at longer context lengths.
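
For readers unfamiliar with the API this PR builds on, here is a minimal sketch of the FlexAttention building blocks involved (a causal mask_mod, a BlockMask, and the flex_attention call). Shapes and names are illustrative assumptions; this is not the PR's actual code.

```python
# Hedged sketch of the FlexAttention pieces this integration relies on.
# Tensor shapes are illustrative only.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    # A query position may only attend to key positions at or before it.
    return q_idx >= kv_idx

B, H, S, D = 1, 32, 4096, 128
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

# The block mask records which (query block, key block) pairs need computation;
# it is built once and can later be sliced per decode step.
block_mask = create_block_mask(causal_mask, B, 1, S, S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
```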

@BoyuanFeng self-assigned this Aug 21, 2024
@facebook-github-bot added the CLA Signed label Aug 21, 2024
model.py (outdated):
```diff
@@ -89,7 +103,7 @@ def update(self, input_pos, k_val, v_val):
         return k_out, v_out

 class Transformer(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: ModelArgs, get_mask_mod: Callable[[int], _mask_mod_signature]) -> None:
```
Contributor:
get_mask_mod shouldn't take an integer; it should take a mask_mod. We also don't need to pass it as an argument, just set it as an attribute within the module.

Contributor:
Specifically, you should be able to take any existing mask_mod and wrap it to make it automatically support an offset.
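
For example, a wrapper along these lines would do it (a hedged sketch; the helper name and the tensor offset argument are assumptions, not necessarily the final shape of this PR):

```python
import torch
from torch.nn.attention.flex_attention import _mask_mod_signature

def with_offset(mask_mod: _mask_mod_signature, offset: torch.Tensor) -> _mask_mod_signature:
    # Hypothetical wrapper: shift the query index by `offset` so that a single
    # decoded token is masked as if it sat at its true position in the sequence.
    def _offset_mask_mod(b, h, q_idx, kv_idx):
        return mask_mod(b, h, q_idx + offset, kv_idx)
    return _offset_mask_mod

# Usage: any existing mask_mod (e.g. a causal one) stays correct during decoding.
# decode_mask_mod = with_offset(causal_mask, input_pos)
```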

Chillee (Contributor) left a comment:

Mostly looks good (modulo some other nits), although we'll probably want to wait on landing this until the other PRs on pytorch core are landed at least.

merrymercy commented Aug 30, 2024

This looks interesting. I would like to share some numbers we got with torch.compile + flashinfer in sglang; they can serve as good baselines. To run the 32k case, you need to edit config.json to set "max_position_embeddings": 65536.

```bash
# length = 1024
# Decode.  median latency: 0.00596 s, median throughput:    167.67 token/s
python3 -m sglang.bench_latency --model meta-llama/Llama-2-7b-chat-hf --batch-size 1 --input 1024 --output 8 --enable-torch-compile

# length = 32768
# Decode.  median latency: 0.01136 s, median throughput:     88.01 token/s
python3 -m sglang.bench_latency --model meta-llama/Llama-2-7b-chat-hf --batch-size 1 --input 32768 --output 8 --enable-torch-compile
```

You can find more numbers at sgl-project/sglang#1008

Chillee (Contributor) commented Aug 30, 2024

@merrymercy We run on nerfed H100s internally at Meta with only 2.4 TB/s of bandwidth, so these numbers aren't 1:1 comparable.

But it's a good comparison :)

generate.py:
```python
logits = model(x, input_pos)
# Select the row of mask blocks for the current decode position; BlockMask
# indexing drops mask_mod, so it is re-attached explicitly below.
block_index = input_pos // block_mask.BLOCK_SIZE[0]
mask = block_mask[:, :, block_index]
mask.mask_mod = block_mask.mask_mod
```
BoyuanFeng (Author):
Discussed offline: BlockMask.__getitem__ sets mask_mod to None, so the user needs to supply the correct mask_mod. In gpt-fast, we rely on model.get_mask_mod to do so.
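
Concretely, the decode-step pattern described here might look roughly as follows (a sketch only; the exact signature of model.get_mask_mod in this PR may differ):

```python
# Hypothetical: slicing a BlockMask drops its mask_mod, so re-attach one that
# accounts for the current decode position via the model-provided helper.
block_index = input_pos // block_mask.BLOCK_SIZE[0]
mask = block_mask[:, :, block_index]
mask.mask_mod = model.get_mask_mod(block_mask.mask_mod, input_pos)  # assumed signature
```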

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Sep 4, 2024
…compile (#134627)

Adds a helper function for getting the block mask for a specific row index during decoding. We need this change to avoid the pytree + torch.compile issue #134731. Tested in gpt-fast [pr](pytorch-labs/gpt-fast#196).

Pull Request resolved: #134627
Approved by: https://github.com/Chillee
```python
    return sample(logits, **sampling_kwargs)

def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
    block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device="cuda")
```
Try doing `create_block_mask_compile = torch.compile(create_block_mask)` as a global.
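
For reference, a minimal sketch of that suggestion (assuming create_block_mask is imported from torch.nn.attention.flex_attention):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# Compile the mask construction once at module scope; the compiled version is
# reused on every call instead of running the slower eager path.
create_block_mask_compile = torch.compile(create_block_mask)

# Inside decode_n_tokens (see the snippet above), the call then becomes:
# block_mask = create_block_mask_compile(causal_mask, 1, 1, model.max_seq_length,
#                                        model.max_seq_length, device="cuda")
```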
