enable benchmark script (#1554)
* enable benchmark script

Signed-off-by: jiqing-feng <[email protected]>

* Small fixes to non_cuda_backends.mdx

---------

Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: Titus <[email protected]>
jiqing-feng and Titus-von-Koeller authored Mar 4, 2025
1 parent 2640753 commit c66e137
Showing 2 changed files with 79 additions and 5 deletions.
67 changes: 67 additions & 0 deletions benchmarking/generation_benchmark.py
@@ -0,0 +1,67 @@
import argparse

import torch
import torch.utils.benchmark as benchmark
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

parser = argparse.ArgumentParser()

parser.add_argument(
"--model_name", default="meta-llama/Llama-3.1-8B-Instruct", required=False, type=str, help="model_name"
)
parser.add_argument("--quant_type", default="int8", type=str, help="quant type", choices=["int8", "nf4", "fp4"])
parser.add_argument("--device_map", default="cpu", type=str, help="device_map", choices=["cpu", "xpu", "cuda"])
args = parser.parse_args()

model_name = args.model_name
device_map = args.device_map
if args.quant_type == "int8":
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
else:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=args.quant_type,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map=device_map, quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))


# benchmark the performance
def benchmark_fn(f, *args, **kwargs):
    # Manual warmup
    for _ in range(2):
        f(*args, **kwargs)

    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return t0.blocked_autorange().mean


MAX_NEW_TOKENS = 100

quantized_model_latency = benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS)

bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, torch_dtype=torch.bfloat16)
bf16_model_latency = benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS)

print(f"bnb model latency: {quantized_model_latency:.3f}")
print(f"bf16 model latency: {bf16_model_latency:.3f}")
print(f"BNB vs. bf16 model speed-up: {(bf16_model_latency / quantized_model_latency):.3f}")

print(f"BNB model memory: {(quantized_model.get_memory_footprint() / 1024 / 1024 / 1024):.3f} GB")
print(f"bf16 model memory: {(bf16_model.get_memory_footprint() / 1024 / 1024 / 1024):.3f} GB")
print(
f"BNB vs. bf16 model memory ratio: {(bf16_model.get_memory_footprint() / quantized_model.get_memory_footprint()):.3f}"
)
17 changes: 12 additions & 5 deletions docs/source/non_cuda_backends.mdx
@@ -27,18 +27,25 @@ Thank you for your support!

### Intel

The performance data below was collected on an Intel 4th Gen Xeon (SPR) platform. The tables show the speed-up and memory footprint of [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) across the supported data types.

You can run `benchmarking/generation_benchmark.py` to reproduce the model memory and inference results below. Note that you need to bind cores when benchmarking on CPU; for example, run `numactl -C 0-55 -m 0 python generation_benchmark.py --quant_type nf4` on a single socket of an Intel 4th Gen Xeon.
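
The other quantization types and devices are covered by the same script; the sketch below simply varies the `--quant_type` and `--device_map` flags defined in the script's argument parser, reusing the 56-core binding from the example above (the core range is only an example and should be adjusted to your machine).

```bash
# Sketch: sweep the CPU quantization types (assumes a 56-core socket; adjust -C to your hardware)
numactl -C 0-55 -m 0 python generation_benchmark.py --quant_type int8 --device_map cpu
numactl -C 0-55 -m 0 python generation_benchmark.py --quant_type nf4 --device_map cpu
numactl -C 0-55 -m 0 python generation_benchmark.py --quant_type fp4 --device_map cpu

# On an accelerator, core binding is not required (assumption); pick xpu or cuda
python generation_benchmark.py --quant_type nf4 --device_map xpu
```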

The fine-tuning results are taken from the [peft](https://github.com/huggingface/peft/blob/main/examples/olora_finetuning/olora_finetuning.py) OLoRA fine-tuning example.

#### Model memory (CPU)
| Data Type | BF16 | INT8 | NF4 | FP4 |
|---|---|---|---|---|
| Memory (GB) | 15.0 | 8.5 | 5.2 | 5.2 |

#### Inference (CPU)

| Data Type | BF16 | INT8 | NF4 | FP4 |
|---|---|---|---|---|
| Speed-Up (vs BF16) | 1.0x | 0.57x | 2.6x | 0.1x |

#### Fine-Tuning (CPU)

| Data Type | BF16 | INT8 | NF4 | FP4 |
|---|---|---|---|---|
| Speed-Up (vs BF16) | 1.0x | 0.91x | 1.0x | 1.0x |
