#12120: Create llama3-70b device perf tests for prefill/decode
Signed-off-by: Salar Hosseini <[email protected]>
skhorasganiTT committed Sep 9, 2024
1 parent 30078ea commit 3dc78e5
Showing 6 changed files with 196 additions and 13 deletions.
161 changes: 161 additions & 0 deletions models/demos/t3000/llama2_70b/tests/test_llama_device_perf.py
@@ -0,0 +1,161 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import sys
import pytest
from models.utility_functions import skip_for_grayskull
from models.demos.t3000.llama2_70b.tt.llama_common import setup_llama_env, check_mesh_device
from models.demos.t3000.llama2_70b.tests.test_llama_model import run_test_LlamaModel_inference
from models.demos.t3000.llama2_70b.tests.test_llama_model_t3000 import N_LAYERS_TO_PCC
from models.demos.t3000.llama2_70b.tests.test_llama_model import DEVICE_PERF_START_SIGNPOST
from models.demos.t3000.mixtral8x7b.scripts.op_perf_results import main as calculate_op_perf_results
from tt_metal.tools.profiler.process_model_log import run_device_profiler, get_latest_ops_log_filename
from models.perf.device_perf_utils import check_device_perf


@pytest.mark.parametrize(
    "llama_version",
    (("llama3"),),
)
@pytest.mark.parametrize("n_layers", (1,), ids=("1L",))
@pytest.mark.parametrize(
    "batch, seq_len, generation_start_pos",
    (
        # Decode, batch 16
        (16, 1, 127),
        (16, 1, 2047),
        (16, 1, 4095),
        (16, 1, 8191),
        # Decode, batch 32
        (32, 1, 127),
        (32, 1, 2047),
        (32, 1, 4095),
        # Prefill
        (1, 128, 0),
        (1, 2048, 0),
        (1, 4096, 0),
        (1, 8192, 0),
    ),
    ids=(
        "decode_128_batch16",
        "decode_2048_batch16",
        "decode_4096_batch16",
        "decode_8192_batch16",
        "decode_128_batch32",
        "decode_2048_batch32",
        "decode_4096_batch32",
        "prefill_128",
        "prefill_2048",
        "prefill_4096",
        "prefill_8192",
    ),
)
@skip_for_grayskull()
def test_run_device_perf_llama(
    batch,
    seq_len,
    generation_start_pos,
    n_layers,
    t3k_mesh_device,
    llama_version,
    use_program_cache,
):
    max_batch_size = batch if seq_len == 1 else 16  # max_batch_size is 16 for prefill
    max_context_len = {16: 8192, 32: 4096}[max_batch_size]  # set max context depending on max batch

    model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env(
        llama_version=llama_version,
        batch=batch,
        seq_len=seq_len,
        max_batch_size=max_batch_size,
        max_context_len=max_context_len,
    )

    check_mesh_device(t3k_mesh_device, model_config)

    run_test_LlamaModel_inference(
        t3k_mesh_device,
        batch,
        seq_len,
        N_LAYERS_TO_PCC[n_layers],
        model_config,
        n_layers,
        llama_version,
        ckpt_dir,
        tokenizer_path,
        cache_path,
        generation_start_pos=generation_start_pos,
        device_perf=True,
    )


@pytest.mark.models_device_performance_bare_metal
@pytest.mark.parametrize(
    "test_id, expected_throughput",
    (
        ("decode_128_batch16", 16.9),  # Issue #9028
        ("decode_2048_batch16", 0),  # Issue #9028
        ("decode_4096_batch16", 0),  # Issue #9028
        ("decode_8192_batch16", 0),  # Issue #9028
        ("decode_128_batch32", 16.6),
        ("decode_2048_batch32", 14.1),
        ("decode_4096_batch32", 12.8),
        ("prefill_128", 713),
        ("prefill_2048", 1036),
        ("prefill_4096", 1024),
        ("prefill_8192", 989),
    ),
)
@skip_for_grayskull()
def test_device_perf_llama(
    test_id,
    expected_throughput,  # t/s for prefill, t/s/u for decode
    is_ci_env,
):
    if is_ci_env:
        if test_id in ["decode_128_batch16", "decode_2048_batch16", "decode_4096_batch16", "decode_8192_batch16"]:
            pytest.skip("Skipping on CI due to Issue #9028")

    margin = 0.03
    subdir = "llama3-70b"
    command = (
        f"pytest models/demos/t3000/llama2_70b/tests/test_llama_device_perf.py::test_run_device_perf_llama -k {test_id}"
    )

    # Run profiler
    run_device_profiler(command, output_logs_subdir=subdir)

    # Prepare the arguments to calculate the ops performance results
    ops_perf_filename = get_latest_ops_log_filename(subdir)
    llm_mode, seq_len, *_ = test_id.split("_")
    if llm_mode == "decode":
        skip_first = 3  # embeddings, i2s (embeddings), i2s (rot-mat)
        skip_last = 3  # all-gather, rms-norm, lm-head
    else:
        skip_first = 1  # embeddings
        skip_last = 5  # ln pre-all-gather, all-gather, ln post-all-gather, all-gather, matmul
    n_layers_total = 80
    sys.argv = [
        "op_perf_results.py",
        f"{ops_perf_filename}",
        "--signpost",
        DEVICE_PERF_START_SIGNPOST,
        "--skip-first",
        f"{skip_first}",
        "--skip-last",
        f"{skip_last}",
        "--seqlen",
        f"{seq_len}",
        "--estimate-full-model",
        f"{n_layers_total}",
    ]
    if llm_mode == "prefill":
        sys.argv.append("--prefill")

    # Calculate the ops performance results using the system arguments above
    measured_throughput = calculate_op_perf_results()  # t/s for prefill, t/s/u for decode

    check_device_perf(
        {"throughput": measured_throughput}, margin, {"throughput": expected_throughput}, assert_on_fail=True
    )
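
Note on the throughput check: check_device_perf compares the measured value against the expected target using the 3% margin defined above. The following sketch only illustrates that pass/fail criterion as a lower-bound check; the real helper lives in models/perf/device_perf_utils.py and does more (logging, result packaging), and check_throughput_with_margin is a hypothetical name, not part of the repository.

def check_throughput_with_margin(measured, expected, margin=0.03):
    # Illustrative only: fail if measured throughput drops more than `margin` below expected
    lower_bound = expected * (1 - margin)
    assert measured >= lower_bound, (
        f"throughput {measured:.1f} below lower bound {lower_bound:.1f} "
        f"(expected {expected:.1f} with {margin:.0%} margin)"
    )

# Example with the prefill_2048 target from the table above (1036 t/s); 1010 t/s passes
check_throughput_with_margin(measured=1010.0, expected=1036)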
20 changes: 17 additions & 3 deletions models/demos/t3000/llama2_70b/tests/test_llama_model.py
@@ -33,6 +33,10 @@
 import gc
 
 
+DEVICE_PERF_START_SIGNPOST = "START_PERF_RUN"
+DEVICE_PERF_END_SIGNPOST = "END_PERF_RUN"
+
+
 class PytorchLlamaModel(torch.nn.Module):
     def __init__(self, hf_reference_model):
         super().__init__()
@@ -69,7 +73,12 @@ def run_test_LlamaModel_inference(
     tokenizer_path,
     cache_path,
     prompt_file=None,
+    generation_start_pos=0,
+    device_perf=False,  # set to True when measuring device perf
 ):
+    if device_perf:  # Enable tracy signpost support in device perf runs only
+        from tracy import signpost
+
     # Load prompt file if provided
     prompt = None
     if prompt_file:
Expand Down Expand Up @@ -109,11 +118,9 @@ def run_test_LlamaModel_inference(
cache_path=cache_path,
)

if model_config["LLM_MODE"] == "prefill":
generation_start_pos = 0
if model_config["LLM_MODE"] == "prefill" or device_perf:
generation_length = 1
else:
generation_start_pos = UNIT_TEST_START_POS
generation_length = UNIT_TEST_GENERATION_LENGTH

# Pre-process inputs in prompt mode
Expand Down Expand Up @@ -148,6 +155,9 @@ def run_test_LlamaModel_inference(
)

# TT hardware execution -------------------------------------------------------------
if device_perf:
signpost(DEVICE_PERF_START_SIGNPOST) # start for device perf measurement

tt_inp_emb, start_pos, rot_mat, attn_mask = tt_model.prepare_inputs(tt_inp_ids, start_pos)

tt_out = tt_model(
@@ -160,6 +170,10 @@
 
     tt_out = ttnn.from_device(tt_out)
     tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=3))
+
+    if device_perf:
+        signpost(DEVICE_PERF_END_SIGNPOST)  # end for device perf measurement
+
     tt_out = tt_out[..., : configuration.vocab_size]
     tt_out = tt_out.permute(2, 1, 0, 3).squeeze()  # [batch, hidden_dim]
     if model_config["LLM_MODE"] == "decode":
14 changes: 7 additions & 7 deletions models/demos/t3000/llama2_70b/tests/test_llama_model_t3000.py
@@ -9,6 +9,11 @@
 from models.demos.t3000.llama2_70b.tests.test_llama_model import run_test_LlamaModel_inference
 
 
+N_LAYERS_TO_PCC = {
+    1: 0.99,
+}
+
+
 @skip_for_grayskull("Requires eth connected devices to run")
 # @pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="See GH Issue #10317")
 @pytest.mark.parametrize(
@@ -18,11 +23,7 @@
         ("llama3"),
     ),
 )
-@pytest.mark.parametrize(
-    "pcc, n_layers",
-    ((0.99, 1),),
-    ids=("1L",),
-)
+@pytest.mark.parametrize("n_layers", (1,), ids=("1L",))
 @pytest.mark.parametrize(
     "batch, seq_len",
     ((32, 1), (1, 128), (1, 2048), (1, 8192)),
@@ -42,7 +43,6 @@
 def test_LlamaModel_inference(
     batch,
     seq_len,
-    pcc,
     n_layers,
     t3k_mesh_device,
     max_batch_size,
@@ -73,7 +73,7 @@ def test_LlamaModel_inference(
         t3k_mesh_device,
         batch,
         seq_len,
-        pcc,
+        N_LAYERS_TO_PCC[n_layers],
         model_config,
         n_layers,
         llama_version,
2 changes: 1 addition & 1 deletion models/demos/t3000/llama2_70b/tt/model_config.py
@@ -33,7 +33,7 @@ def get_model_config(
     llm_mode = "decode" if seq_len == 1 else "prefill"
     assert num_devices == 8
     assert batch in (1, 16, 32)
-    assert seq_len in (1, 128, 256, 2048, 8192, 32 * 1024, 128 * 1024)
+    assert seq_len in (1, 128, 256, 2048, 4096, 8192, 32 * 1024, 128 * 1024)
 
     # Supported values, TODO update for larger TT chips
     if max_context_len > 4096:
2 changes: 2 additions & 0 deletions models/demos/t3000/mixtral8x7b/scripts/op_perf_results.py
@@ -91,6 +91,8 @@ def main():
     if args.write_ops_to_csv:
         write_blocks_to_csv(blocks, args.write_ops_to_csv)
 
+    return tokens_per_s
+
 
 def read_rows(csv_file):
     with open(csv_file, "r") as f:
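
The tokens_per_s value returned here is what test_device_perf_llama consumes as the measured throughput. Conceptually, the script trims the non-repeating ops at the start and end of the single-layer run (--skip-first / --skip-last) and scales the remaining per-layer device time up to the full model (--estimate-full-model 80). A simplified sketch of that extrapolation, using hypothetical per-layer times rather than the script's actual implementation:

# Hypothetical per-layer device times taken from trimmed single-layer ops logs (nanoseconds)
prefill_layer_ns = 25_000_000  # one layer of a 2048-token prefill
decode_layer_ns = 880_000  # one layer of a single decode step

n_layers_total = 80  # --estimate-full-model 80

# t/s for prefill: the whole sequence is produced by one full-model forward pass
prefill_tokens_per_s = 2048 / (prefill_layer_ns * n_layers_total * 1e-9)  # ~1024

# t/s/u for decode: each user gets one token per full-model forward pass
decode_tokens_per_s_per_user = 1 / (decode_layer_ns * n_layers_total * 1e-9)  # ~14.2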
10 changes: 8 additions & 2 deletions tt_metal/tools/profiler/process_model_log.py
@@ -10,9 +10,15 @@
 from tt_metal.tools.profiler.common import PROFILER_OUTPUT_DIR, PROFILER_SCRIPTS_ROOT
 
 
-def post_process_ops_log(output_logs_subdir, columns, sum_vals=True, op_name="", has_signposts=False):
+def get_latest_ops_log_filename(output_logs_subdir):
     runDate = sorted(os.listdir(PROFILER_OUTPUT_DIR / output_logs_subdir))[-1]
-    df = pd.read_csv(PROFILER_OUTPUT_DIR / output_logs_subdir / runDate / f"ops_perf_results_{runDate}.csv")
+    filename = PROFILER_OUTPUT_DIR / output_logs_subdir / runDate / f"ops_perf_results_{runDate}.csv"
+    return filename
+
+
+def post_process_ops_log(output_logs_subdir, columns, sum_vals=True, op_name="", has_signposts=False):
+    filename = get_latest_ops_log_filename(output_logs_subdir)
+    df = pd.read_csv(filename)
 
     if has_signposts:
         # there are explicit start and stop points in the model we want to measure between
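
With get_latest_ops_log_filename factored out, the most recent ops log can also be inspected directly. A minimal usage sketch, assuming a profiler run has already written a log under the llama3-70b subdirectory:

import pandas as pd

from tt_metal.tools.profiler.process_model_log import get_latest_ops_log_filename

ops_csv = get_latest_ops_log_filename("llama3-70b")  # most recent run under this subdir
df = pd.read_csv(ops_csv)
print(f"{len(df)} op entries in {ops_csv}")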
