
Profiler improvements (#355)
LucasWilkinson authored Jul 9, 2024
1 parent 537957c commit ccaba64
Showing 3 changed files with 92 additions and 62 deletions.
137 changes: 84 additions & 53 deletions examples/offline_profile.py
@@ -12,19 +12,21 @@

BATCH_SIZE_DEFAULT = 1
PROMPT_LEN_DEFAULT = 256
MAX_SEQ_LEN_DEFAULT = 1024
OUTPUT_LEN_DEFAULT = 2


@dataclass
class ProfileContext:
model: str
tokenizer: str
model_revision: str
sparsity: str
quantization: str
max_seq_len: int
max_model_len: int
max_num_batched_tokens: int
prompt_len: int
output_len: int
batch_size: int
dtype: str
tensor_parallel_size: int
allow_cuda_graphs: bool

@@ -38,26 +40,29 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
# Create sampling params
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=8,
max_tokens=context.output_len,
ignore_eos=True)

# Sparsity is in the future
# Create LLM
llm = LLM(
model=context.model,
revision=context.model_revision,
sparsity=context.sparsity,
enforce_eager=not context.allow_cuda_graphs,
tensor_parallel_size=context.tensor_parallel_size,
gpu_memory_utilization=0.9,
max_model_len=context.max_seq_len,
quantization=context.quantization,
max_num_batched_tokens=context.max_num_batched_tokens,
)
llm = LLM(model=context.model,
tokenizer=context.tokenizer
if context.tokenizer is not None else context.model,
revision=context.model_revision,
enforce_eager=not context.allow_cuda_graphs,
tensor_parallel_size=context.tensor_parallel_size,
gpu_memory_utilization=0.9,
max_model_len=context.max_model_len,
quantization=context.quantization,
dtype=context.dtype,
max_num_batched_tokens=context.max_num_batched_tokens)

batch_size = context.batch_size
prompt_len = context.prompt_len
output_len = context.output_len

scheduler_config = llm.llm_engine.scheduler_config
max_model_len = llm.llm_engine.model_config.max_model_len
max_num_batched_tokens = scheduler_config.max_num_batched_tokens
max_num_seqs = scheduler_config.max_num_seqs

@@ -75,6 +80,15 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
f"single profile step, please choose a smaller batch size")
sys.exit(-1)
print("llm.llm_engine.model_config.max_model_len: ",
llm.llm_engine.model_config.max_model_len)
if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
print(
f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
f"{output_len} = {prompt_len + output_len}) is larger than the "
f"model's max_model_len ({max_model_len}), please choose a smaller "
f"prompt_len or output_len, or increase --max-model-len")
sys.exit(-1)

for i in range(batch_size):
prompt_token_ids = torch.randint(
@@ -89,50 +103,59 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
with nm_profile() as prefill_prof:
llm.llm_engine.step() # First step is prefill

with nm_profile() as decode_prof:
llm.llm_engine.step()
decode_results_list = []
for _ in range(context.output_len - 1):
with nm_profile() as decode_prof:
llm.llm_engine.step()
decode_results_list.append(decode_prof.results)

prefill_results = prefill_prof.results
decode_results = decode_prof.results
has_decode = len(decode_results_list) > 0
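(Note: range(context.output_len - 1) means an --output-len of N profiles one prefill step followed by N - 1 decode steps, so with the default OUTPUT_LEN_DEFAULT = 2 exactly one decode step is captured; this is why the downstream tools later in this commit refer to the first decode phase as decode_1.)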

print("=" * 80)
print(f"= Prefill Model Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * 80)
print()
prefill_results.print_model_table()
print()
print("=" * 80)
print(f"= Decode Model Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * 80)
print()
decode_results.print_model_table()

if has_decode:
print()
print("=" * 80)
print(f"= First Decode Step Model Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * 80)
print()
decode_results_list[0].print_model_table()

print()
print("=" * 80)
print(f"= Prefill Summary Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * 80)
print()
prefill_results.print_summary_table()
print()
print("=" * 80)
print(f"= Decode Summary Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * 80)
print()
decode_results.print_summary_table()
if has_decode:
print()
print("=" * 80)
print(f"= First Decode Step Summary Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * 80)
print()
decode_results_list[0].print_summary_table()
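The committed code above only prints tables for the first profiled decode step (decode_results_list[0]). A minimal sketch for printing one summary table per decode step instead, continuing the variables already in scope above (this loop is an illustration, not part of the commit):

    # Sketch only: one summary table per profiled decode step.
    for step, step_results in enumerate(decode_results_list, start=1):
        print()
        print("=" * 80)
        print(f"= Decode Step {step} Summary Table "
              f"(prompt_len={prompt_len}, batch_size={batch_size})")
        print("=" * 80)
        print()
        step_results.print_summary_table()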

if csv_output:
csv_filename_base = csv_output.rstrip(".csv")
prefill_results.export_model_stats_table_csv(
csv_filename_base + "_prefill_model_table.csv")
prefill_results.export_summary_stats_table_csv(
csv_filename_base + "_prefill_summary_table.csv")
decode_results.export_model_stats_table_csv(\
csv_filename_base + "_decode_model_table.csv")
decode_results.export_summary_stats_table_csv(
csv_filename_base + "_decode_summary_table.csv")

if has_decode:
decode_results_list[0].export_model_stats_table_csv(\
csv_filename_base + "_decode_model_table.csv")
decode_results_list[0].export_summary_stats_table_csv(
csv_filename_base + "_decode_summary_table.csv")

if json_output:
cuda_devices = [
@@ -149,9 +172,12 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
**asdict(context)
},
"prefill": prefill_results.convert_stats_to_dict(),
"decode": decode_results.convert_stats_to_dict()
}

if has_decode:
for idx, dr in enumerate(decode_results_list):
json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()

with open(json_output.rstrip(".json") + ".json", "w+") as f:
json.dump(json_dict, f, indent=2)
pass
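With this change the exported JSON carries one entry per profiled decode step (decode_1, decode_2, ...) instead of a single decode entry. A minimal consumption sketch, assuming the trace was written to profile.json (placeholder filename) and using the key layout shown above:

    import json

    # Sketch only: walk the per-phase stats in a trace exported by this script.
    with open("profile.json") as f:
        trace = json.load(f)

    ctx = trace["context"]
    print(ctx["model"], ctx["batch_size"], ctx["prompt_len"])

    decode_keys = sorted((k for k in trace if k.startswith("decode_")),
                         key=lambda k: int(k.split("_")[1]))
    for phase in ["prefill"] + decode_keys:
        roots = trace[phase]["summary_stats"]
        print(f"{phase}: {len(roots)} top-level summary entries")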
@@ -165,6 +191,11 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
type=str,
required=True,
help='The name or path of a HuggingFace Transformers model.')
parser.add_argument("--tokenizer",
type=str,
default=None,
help="path to the tokenizer")

parser.add_argument("--model-revision", type=str, default=None)
parser.add_argument(
"--csv",
@@ -180,29 +211,23 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
type=str,
default=None,
help="Export the results as a json file. This should be the filename")
parser.add_argument(
"--sparsity",
"-s",
type=str,
choices=[None, 'sparse_w16a16', 'semi_structured_sparse_w16a16'],
help="Method used to compress sparse weights. If "
"None, we first check the `sparsity_config` attribute"
"in the model config file. If that is None we assume"
"the model weights are dense")
parser.add_argument(
"--quantization",
"-q",
type=str,
choices=['awq', 'gptq', 'squeezellm', 'marlin', None],
choices=['awq', 'gptq', 'squeezellm', 'marlin', 'smoothquant', None],
default=None,
help="The method used to quantize the model weights, "
"options are \"marlin\", \"awq\", \"gptq\" and \"squeezellm\"")
help="The method used to quantize the model weights, options are "
"\"marlin\", \"awq\", \"gptq\", \"squeezellm\", \"smoothquant\"")
parser.add_argument("--dtype",
type=str,
default='auto',
help="model dtype")
parser.add_argument(
"--max-seq-len",
"--max-model-len",
type=int,
default=MAX_SEQ_LEN_DEFAULT,
help=f"Maximum length of a sequence (including prompt and output), "
f"default={MAX_SEQ_LEN_DEFAULT}")
default=None,
help="Maximum length of a sequence (including prompt and output)")
parser.add_argument(
"--max-num-batched-tokens",
type=int,
@@ -216,6 +241,12 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
default=PROMPT_LEN_DEFAULT,
help=f"Length of the random prompt to use when profiling, all batched "
f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
parser.add_argument(
"--output-len",
type=int,
default=OUTPUT_LEN_DEFAULT,
help=
f"Number of output decode steps to run, default={OUTPUT_LEN_DEFAULT}")
parser.add_argument("--batch-size",
type=int,
default=BATCH_SIZE_DEFAULT,
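For reference, a rough sketch of driving run_profile directly with the reworked ProfileContext (field names follow the dataclass in this diff; the example model id, the placeholder max_num_batched_tokens value, and the json_output parameter name are assumptions, and the script is normally driven via the CLI flags above):

    # Sketch only; assumes offline_profile.py is importable from the working dir.
    from offline_profile import ProfileContext, run_profile

    context = ProfileContext(
        model="facebook/opt-125m",      # placeholder model id
        tokenizer=None,                 # None falls back to the model's own tokenizer
        model_revision=None,
        quantization=None,
        max_model_len=1024,
        max_num_batched_tokens=8192,    # placeholder value
        prompt_len=256,
        output_len=2,                   # one prefill step + one profiled decode step
        batch_size=1,
        dtype="auto",
        tensor_parallel_size=1,
        allow_cuda_graphs=False,
    )
    run_profile(context, csv_output=None, json_output="my_profile")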
2 changes: 1 addition & 1 deletion neuralmagic/tools/profiler/print_table.py
@@ -34,7 +34,7 @@ def get_entries(node, curr_depth=0):
"examples/offline_profile.py")
parser.add_argument("--phase",
type=str,
choices=["prefill", "decode"],
choices=["prefill", "decode_1"],
required=True,
help="The phase to print the table for.")
parser.add_argument("--table",
15 changes: 7 additions & 8 deletions neuralmagic/tools/profiler/visualize_trace.py
@@ -118,7 +118,7 @@ def get_entries_at_depth(depth,

for root in profile_data["prefill"]["summary_stats"]:
get_entries_at_depth(depth, prefill_entries_and_traces, root)
for root in profile_data["decode"]["summary_stats"]:
for root in profile_data["decode_1"]["summary_stats"]:
get_entries_at_depth(depth, decode_entries_and_traces, root)

def attempt_to_make_names_unique(entries_and_traces):
@@ -199,12 +199,11 @@ def plot_metric(metric: str, ax, add_totals=False):
shorten_plot_legend_strings(legend, 50)

context = profile_data["context"]
plt.suptitle(
f"{context['model']}\n"
f"Batch={context['batch_size']}, "
f"PromptLen={context['prompt_len']}, "
f"NumGpus={context['tensor_parallel_size']}"
f"{', Sparsity ' + context['sparsity'] if context['sparsity'] else ''}"
)
sparsity = context.get('sparsity', None)
plt.suptitle(f"{context['model']}\n"
f"Batch={context['batch_size']}, "
f"PromptLen={context['prompt_len']}, "
f"NumGpus={context['tensor_parallel_size']}"
f"{', Sparsity ' + sparsity if sparsity else ''}")
plt.savefig(output, bbox_inches='tight')
print("Created: ", output)
