Commit: wip
jcaip committed Nov 25, 2024
1 parent 112aeff commit 1a4fe62
Showing 4 changed files with 24 additions and 13 deletions.
1 change: 0 additions & 1 deletion test/prototype/test_sparse_api.py
@@ -21,7 +21,6 @@
     TORCH_VERSION_AT_LEAST_2_6,
 )
 
-from torch.sparse import SparseSemiStructuredTensor
 
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
3 changes: 2 additions & 1 deletion torchao/_models/llama/benchmark_results.txt
@@ -71,4 +71,5 @@ TTFT(Time to First Token) Benchmarks
 20241022205645, tok/s=132.49, mem/s=1770.83 GB/s, time=1.5092 sec, peak_mem=18.61 GB, model_size=13.37 GB quant: None, sparse: semi-structured, mod: SparseLlama-3-8B-pruned_50.2of4, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --sparsity semi-structured --checkpoint_path ../../../checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20241125151919, tok/s=132.38, mem/s=1987.00 GB/s, time=1.5105 sec, peak_mem=16.20 GB, model_size=15.01 GB quant: None, sparse: None, mod: SparseLlama-3-8B-pruned_50.2of4, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
 20241125151958, tok/s=129.89, mem/s=1736.04 GB/s, time=1.5389 sec, peak_mem=18.63 GB, model_size=13.37 GB quant: None, sparse: semi-structured, mod: SparseLlama-3-8B-pruned_50.2of4, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --sparsity semi-structured --checkpoint_path ../../../checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
-20241125152104, tok/s=269.73, mem/s= 822.48 GB/s, time=0.7417 sec, peak_mem= 5.03 GB, model_size= 3.05 GB quant: int4wo-64, sparse: semi-structured, mod: SparseLlama-3-8B-pruned_50.2of4, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization int4wo-64 --sparsity semi-structured --checkpoint_path ../../../checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241125152104, tok/s=269.73, mem/s= 822.48 GB/s, time=0.7417 sec, peak_mem= 5.03 GB, model_size= 3.05 GB quant: int4wo-64, sparse: semi-structured, mod: SparseLlama-3-8B-pruned_50.2of4, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization int4wo-64 --sparsity semi-structured --checkpoint_path ../../../checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241125152925, tok/s= 1.74, mem/s= 26.13 GB/s, time=0.5738 sec, peak_mem=20.95 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --ttft_prefill_size 8192--num_samples 5 --max_new_tokens 1 --batch_size 1 --top_k 200 --temperature 0.8
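Each results line above follows one flat layout: a timestamp, numeric metrics as key=value pairs (tok/s, mem/s, time, peak_mem, model_size), config key: value pairs, and finally the exact repro command. A minimal, hypothetical parser for such lines (parse_result_line is illustrative only, not part of the repo):

import re

def parse_result_line(line: str) -> dict:
    """Split one benchmark_results.txt line into numeric metrics plus the repro command."""
    fields = {}
    # Everything after "repro:" is the reproduction command.
    head, _, repro = line.partition("repro:")
    fields["repro"] = repro.strip()
    # Numeric metrics appear as key=value pairs, e.g. "tok/s=132.49" or "mem/s= 822.48".
    for key, value in re.findall(r"([\w/]+)=\s*([\d.]+)", head):
        fields[key] = float(value)
    return fields

example = ("20241125151919, tok/s=132.38, mem/s=1987.00 GB/s, time=1.5105 sec, "
           "peak_mem=16.20 GB, model_size=15.01 GB quant: None, sparse: None "
           "repro: python generate.py ...")
print(parse_result_line(example)["tok/s"])  # -> 132.38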
19 changes: 15 additions & 4 deletions torchao/_models/llama/benchmarks.sh
@@ -67,10 +67,21 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
 
 # 2:4 sparse model
-export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision torch.float16 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision torch.float16 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --sparsity semi-structured --precision torch.float16 --write_result benchmark_results.txt
+#export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4
+#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision torch.float16 --write_result benchmark_results.txt
+#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision torch.float16 --write_result benchmark_results.txt
+#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --sparsity semi-structured --precision torch.float16 --write_result benchmark_results.txt
 
+# TTFT benchmarks
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 1 --ttft_prefill_size 8192 --compile_prefill
+#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --max_new_tokens 1 --ttft_prefill_size 8192
+#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --write_result benchmark_results.txt --max_new_tokens 1 --ttft_prefill_size 8192
+#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --write_result benchmark_results.txt --max_new_tokens 1 --ttft_prefill_size 8192
+#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --sparsity semi-structured --write_result benchmark_results.txt --max_new_tokens 1 --ttft_prefill_size 8192
+#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8dq --write_result benchmark_results.txt --max_new_tokens 1 --ttft_prefill_size 8192
+#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8wo --write_result benchmark_results.txt --max_new_tokens 1 --ttft_prefill_size 8192
+#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --max_new_tokens 1 --ttft_prefill_size 8192
+
 # Different Batch Size Benchmarks
 #export MODEL_REPO=meta-llama/Meta-Llama-3-8B
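The TTFT runs added above time a single generated token (--max_new_tokens 1) after a long 8192-token prefill. A rough sketch of the quantity they measure, with a generic model callable standing in for the benchmark's actual generate loop, assuming model(input_ids) returns logits of shape [batch, seq, vocab] (illustrative only, not the torchao code):

import time

import torch

def measure_ttft(model, input_ids: torch.Tensor):
    """Time one prefill pass plus the first decoded token (what --max_new_tokens 1 exercises)."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        logits = model(input_ids)                      # prefill over the full prompt
        first_token = logits[:, -1, :].argmax(dim=-1)  # greedy choice of token #1
    torch.cuda.synchronize()
    return time.perf_counter() - start, first_token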
14 changes: 7 additions & 7 deletions torchao/_models/llama/generate.py
@@ -157,7 +157,7 @@ def _load_model(checkpoint_path, device, precision):
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def main(
-    prefill_size: int = 0,
+    ttft_prefill_size: int = 0,
     prompt: str = "Hello, my name is",
     interactive: bool = False,
     num_samples: int = 5,
@@ -183,11 +183,11 @@ def main(
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
     """
 
-    if prefill_size > 0:
+    if ttft_prefill_size > 0:
         print("Running TTFT benchmark!")
-        assert max_new_tokens == 1, "prefill_size only supports max_new_tokens=1"
+        assert max_new_tokens == 1, "ttft_prefill_size only supports max_new_tokens=1"
         # create prompt of prefill size ttft
-        prompt = "prompt " * (int(prefill_size)-3)
+        prompt = "prompt " * (int(ttft_prefill_size)-3)
 
     torchao.quantization.utils.recommended_inductor_config_setter()
 
@@ -567,7 +567,7 @@ def callback(x):
     result_txt += f"--precision {precision} "
     result_txt += f"--compile " if compile else ""
     result_txt += f"--compile_prefill " if compile_prefill else ""
-    result_txt += f"--prefill_size {prefill_size}" if prefill_size else ""
+    result_txt += f"--ttft_prefill_size {ttft_prefill_size}" if ttft_prefill_size else ""
     result_txt += f"--profile {profile} " if profile else ""
     result_txt += f"--profile {memory_profile} " if memory_profile else ""
     result_txt += f"--interactive " if interactive else ""
@@ -589,7 +589,7 @@ def callback(x):
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Your CLI description.')
-    parser.add_argument('--prefill_size', type=int, default=0, help='Whether to run in ttft mode')
+    parser.add_argument('--ttft_prefill_size', type=int, default=0, help='Whether to run in ttft mode')
     parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
@@ -625,6 +625,6 @@ def callback(x):
     args = parser.parse_args()
     print(args)
     main(
-        args.prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
+        args.ttft_prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
         args.temperature, args.checkpoint_path, args.quantization, args.sparsity, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
     )
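For context on what the renamed parameter drives: main() pads the prompt so its token count tracks ttft_prefill_size. A standalone sketch of that padding follows; the one-token-per-word behavior and the -3 offset (presumably absorbing BOS/extra tokens added at encode time) are assumptions read off the diff, not documented guarantees:

def make_ttft_prompt(ttft_prefill_size: int) -> str:
    # Mirrors the logic in main(): repeat a word that should encode to a single
    # token, so the tokenized prompt lands near the requested prefill length.
    return "prompt " * (int(ttft_prefill_size) - 3)

prompt = make_ttft_prompt(8192)
print(len(prompt.split()))  # 8189 repeated words -> roughly an 8192-token prefill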
