Summary

This example demonstrates how to run Llama models on mobile via ExecuTorch. We use XNNPACK to accelerate performance and 4-bit groupwise quantization to fit the model on a phone.

Here are the supported models:

  • Llama 3.2 1B and 3B
  • Llama 3.2 Quantized 1B and 3B
  • Llama 3.1 8B
  • Llama 3 8B
  • Llama 2 7B

Pretrained models are not included in this repo; you can download them here.

This page contains the basic recipe for running Llama. See the Llama utils page for more advanced use cases, such as fine-tuning and running smaller models for educational purposes.

What is Llama?

Llama is a collection of large language models trained on publicly available data. These models are based on the transformer architecture, which allows them to process input sequences of variable length (up to the model's maximum context length) and generate output sequences of variable length. One of the key features of Llama models is their ability to generate coherent and contextually relevant text. This is achieved through attention mechanisms, which allow the model to focus on different parts of the input sequence as it generates output. Llama models are pre-trained on a large corpus of text with a causal (autoregressive) language modeling objective, which teaches them to predict the next token from the preceding context.

Llama models have been shown to perform well on a variety of natural language processing tasks, including language translation, question answering, and text summarization. They are also capable of generating human-like text, which makes them a useful tool for creative writing and other applications where natural language generation is important.

Overall, Llama models are powerful and versatile language models that can be used for a wide range of natural language processing tasks. The model’s ability to generate coherent and contextually relevant text makes it particularly useful for applications such as chatbots, virtual assistants, and language translation.

Please note that the models are subject to the Llama 2 Acceptable Use Policy, Llama 3 Acceptable Use Policy and Responsible Use Guide.

Results

Llama 3.2 1B/3B and quantized 1B/3B models

For Llama 3.2 1B/3B models, we have enabled the original BF16 format and quantization to 4-bit, using SpinQuant and QAT+LoRA, for enhanced performance.

The quantized models were optimized primarily for the Arm CPU architecture by leveraging XNNPACK and the KleidiAI library. Work is underway to specifically enable quantization on mobile accelerators for Llama 1B/3B.

Enablement

We have successfully verified performance on the following devices: iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S24+, S22 and OnePlus 12 (featuring 16GB RAM).

Note that the Llama 3.2 3B unquantized BF16 model was only tested on the OnePlus 12, which has sufficient memory (16GB RAM) to support its size requirements.

Quantization

The 1B/3B models are sensitive to accuracy loss when regular post-training quantization (PTQ) is applied. To achieve a balance between accuracy, performance and memory, we utilized 4-bit quantization, using SpinQuant and QAT+LoRA methods.

Our quantization scheme involves three parts, applicable to both methods:

  • We quantize all linear layers in all transformer blocks to a 4-bit groupwise scheme (with a group size of 32) for weights and 8-bit per-token dynamic quantization for activations.
  • The classification layer is quantized to 8-bit per-channel for weights and 8-bit per-token dynamic quantization for activations.
  • The embedding table is quantized to 8-bit per-channel.

We use torchao library APIs to define these schemes.
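To make the weight scheme above concrete, here is a minimal pure-Python sketch of symmetric 4-bit groupwise quantization with a group size of 32: each run of 32 consecutive weights in a row shares one scale. It only illustrates the idea and is not the torchao implementation; the function names, the symmetric (zero-point-free) scheme, and the signed range [-8, 7] are our assumptions.

```python
def quantize_groupwise_int4(row, group_size=32):
    """Symmetric 4-bit groupwise quantization of one weight row (illustrative sketch)."""
    if len(row) % group_size != 0:
        raise ValueError("row length must be a multiple of group_size")
    qmin, qmax = -8, 7  # signed 4-bit integer range
    codes, scales = [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        max_abs = max(abs(w) for w in group)
        # One scale per group: the largest magnitude in the group maps to qmax.
        scale = max_abs / qmax if max_abs > 0 else 1.0
        scales.append(scale)
        codes.extend(max(qmin, min(qmax, round(w / scale))) for w in group)
    return codes, scales


def dequantize_groupwise(codes, scales, group_size=32):
    """Recover approximate float weights: each code times its group's scale."""
    return [c * scales[i // group_size] for i, c in enumerate(codes)]


# Usage: a 64-element row is split into 2 groups, so it gets 2 scales.
row = [i / 10 for i in range(-32, 32)]
codes, scales = quantize_groupwise_int4(row)
approx = dequantize_groupwise(codes, scales)
print(len(scales), max(abs(a - b) for a, b in zip(row, approx)))
```

Smaller groups track local weight ranges more closely (lower rounding error) at the cost of storing more scales.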

SpinQuant

The SpinQuant method takes the original weights and produces optimized quantized weights with minimal outliers, resulting in higher accuracy. This can be achieved without any fine-tuning of the weights and only requires 100 iterations on a single A100 node.

SpinQuant can generate quantized weights that are compatible with ExecuTorch. Specifically, it can be integrated with the existing optimized XNNPACK kernels (e.g., groupwise 4-bit weights and 8-bit dynamic activations). This allows developers to benefit from the higher accuracy of SpinQuant while also taking advantage of the strong performance of ExecuTorch acceleration.
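SpinQuant reduces outliers with rotations: for an orthogonal matrix R, a layer output W·x equals (W·R)(Rᵀ·x), so rotating weights and activations leaves the result unchanged while spreading a large outlier across many channels. SpinQuant learns its rotations; the sketch below instead uses a fixed normalized Hadamard matrix purely for illustration, and the weight and activation values are made up.

```python
import math


def hadamard(n):
    """Normalized Sylvester Hadamard matrix: orthogonal and symmetric. n must be a power of two."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    h = [[1.0]]
    while len(h) < n:
        h = [r + r for r in h] + [r + [-v for v in r] for r in h]
    s = 1 / math.sqrt(n)
    return [[v * s for v in r] for r in h]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Hypothetical weight row with one large outlier, and an activation vector.
w = [0.1, -0.2, 0.05, 8.0]
x = [1.0, 2.0, 3.0, 4.0]
r_t = hadamard(4)  # symmetric, so this matrix is also R transposed
w_rot, x_rot = matvec(r_t, w), matvec(r_t, x)

# The dot product is unchanged, but the largest weight magnitude shrinks,
# so a per-group quantization scale wastes less range on the outlier.
print(dot(w, x), dot(w_rot, x_rot))
print(max(map(abs, w)), max(map(abs, w_rot)))
```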

Quantization-Aware Training and LoRA (QAT+LoRA)

Quantization-Aware Training (QAT) is employed to simulate the effects of quantization during the training of Llama-3.2 models, enabling optimization of their performance in low precision environments. To initialize QAT, BF16 Llama-3.2 model checkpoints obtained after supervised fine-tuning (SFT) are utilized and an additional full round of SFT training with QAT is performed. The backbone of the QAT model is then frozen and another round of SFT is performed with low-rank adaptation (LoRA) adaptors applied to all layers within the transformer block. Meanwhile, the LoRA adaptors' weights and activations are maintained in BF16.
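The forward pass of a QAT+LoRA layer can be sketched as follows: the frozen backbone weight goes through fake quantization (quantize, then immediately dequantize, so training sees the rounding error), while a low-rank update B·A stays in higher precision. This is a minimal pure-Python sketch, not the Llama training recipe: Python floats stand in for BF16, and the symmetric per-row 4-bit fake quantizer, the alpha / rank scaling (a common LoRA convention), and all numbers are our assumptions.

```python
def fake_quant_int4(row):
    """Quantize a row to symmetric signed 4-bit, then immediately dequantize it."""
    max_abs = max(abs(v) for v in row)
    scale = max_abs / 7 if max_abs > 0 else 1.0
    return [max(-8, min(7, round(v / scale))) * scale for v in row]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def qat_lora_linear(x, w_frozen, lora_a, lora_b, alpha, rank):
    """y = fake_quant(W) x + (alpha / rank) * B (A x); only A and B would be trained."""
    w_q = [fake_quant_int4(row) for row in w_frozen]
    base = matvec(w_q, x)
    update = matvec(lora_b, matvec(lora_a, x))  # A: rank x in, B: out x rank
    return [b + (alpha / rank) * u for b, u in zip(base, update)]


x = [1.0, -0.5, 0.25, 2.0]
w = [[0.3, -0.1, 0.7, 0.2], [-0.4, 0.9, 0.05, -0.6]]  # out=2, in=4
a = [[0.01, 0.02, -0.01, 0.03]]                       # rank=1
b_zero = [[0.0], [0.0]]                               # LoRA B is commonly zero-initialized
y0 = qat_lora_linear(x, w, a, b_zero, alpha=2.0, rank=1)
```

With B initialized to zero, the layer starts out identical to the fake-quantized backbone, and training only moves the small adapter matrices.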

Accuracy

Please see the Llama 3.2 model card for accuracy evaluations.

Performance

Llama 3.2 1B and 3B performance was measured on an Android OnePlus 12 device. Performance is expressed in tokens per second, measured using an adb binary-based approach with a prompt length of 64. Measurements were taken with the KleidiAI library enabled. KleidiAI is not yet enabled by default; use -DEXECUTORCH_XNNPACK_ENABLE_KLEIDI=ON to enable it in the build.

| Model | Decode (tokens/s) | Time-to-first-token (sec) | Prefill (tokens/s) | Model size (PTE file size in MiB) | Memory size (RSS in MiB) |
|---|---|---|---|---|---|
| 1B BF16 (baseline) | 19.2 | 1.0 | 60.3 | 2,358 | 3,185 |
| 1B SpinQuant | 50.2 (2.6x) | 0.3 (-76.9%) | 260.5 (4.3x) | 1,083 (-54.1%) | 1,921 (-39.7%) |
| 1B QAT+LoRA | 45.8 (2.4x) | 0.3 (-76.0%) | 252.0 (4.2x) | 1,127 (-52.2%) | 2,255 (-29.2%) |
| 3B BF16 (baseline) | 7.6 | 3.0 | 21.2 | 6,129 | 7,419 |
| 3B SpinQuant | 19.7 (2.6x) | 0.7 (-76.4%) | 89.7 (4.2x) | 2,435 (-60.3%) | 3,726 (-49.8%) |
| 3B QAT+LoRA | 18.5 (2.4x) | 0.7 (-76.1%) | 88.8 (4.2x) | 2,529 (-58.7%) | 4,060 (-45.3%) |
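The figures in parentheses are relative to the BF16 baseline of the same model size: throughput gains are ratios (quantized / baseline) and size changes are signed percentages. This small sketch reproduces a few 1B SpinQuant entries from the rounded values in the table. The time-to-first-token percentages cannot be reproduced exactly this way, since they appear to come from unrounded timings.

```python
def speedup(new, baseline):
    """Throughput ratio, e.g. decode tokens/s of the quantized model vs. BF16."""
    return new / baseline


def pct_change(new, baseline):
    """Signed percentage change, e.g. PTE size or RSS vs. BF16."""
    return (new - baseline) / baseline * 100


# 1B SpinQuant vs. 1B BF16, values copied from the table above.
print(f"decode:  {speedup(50.2, 19.2):.1f}x")     # table: 2.6x
print(f"prefill: {speedup(260.5, 60.3):.1f}x")    # table: 4.3x
print(f"PTE:     {pct_change(1083, 2358):.1f}%")  # table: -54.1%
print(f"RSS:     {pct_change(1921, 3185):.1f}%")  # table: -39.7%
```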

Llama 3.2 1B, unquantized, BF16 on Android phone.

Llama 3.2 3B, 4-bit quantized (SpinQuant) on Android phone

Llama 3/3.1 8B

Since the Llama 3 8B model needs at least 4-bit quantization to fit even within some high-end phones, the results presented here correspond to a 4-bit groupwise post-training quantized (PTQ) model.
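A rough estimate shows why 4-bit is needed: weight storage alone is the parameter count times bits per weight, divided by 8. The sketch below assumes about 8.0 billion parameters for Llama 3 8B and ignores quantization scales, activations, and the KV cache, so real footprints are somewhat larger.

```python
def weight_gib(num_params, bits_per_weight):
    """Approximate weight storage in GiB (ignores scales, activations and KV cache)."""
    return num_params * bits_per_weight / 8 / 2**30


LLAMA3_8B_PARAMS = 8.0e9  # approximate parameter count; an assumption for illustration
for name, bits in [("BF16", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{name:>5}: {weight_gib(LLAMA3_8B_PARAMS, bits):.1f} GiB")
```

By this estimate, BF16 weights alone need roughly 15 GiB, while 4-bit weights need under 4 GiB, which leaves room for the OS and runtime buffers on a phone.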

Enablement

For Llama 3 8B and Llama 3.1 8B, we have so far verified 4-bit quantized models on iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S24+ and OnePlus 12 (with 16GB RAM).

Quantization

We applied PTQ to all linear layers of the model, using 4-bit groupwise weight quantization with 8-bit per-token dynamic activation quantization. Dynamic quantization means that activations are quantized on the fly: their quantization parameters are calculated at runtime from the observed min/max range. Here we quantized activations to 8-bit signed integers. Weights, by contrast, are statically quantized; in our case they were groupwise quantized per channel to 4-bit signed integers. Because of Llama 3's large vocabulary, we had to quantize the embedding lookup table as well. For these results, the embedding lookup table was groupwise quantized to 4 bits with a group size of 32.

We use torchao library APIs to define these schemes.
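Dynamic per-token activation quantization can be sketched as follows: for each token (row of the activation matrix), a scale and zero point are computed at runtime from that row's min/max, and values are rounded to signed 8-bit. This is an illustrative pure-Python version under our own assumptions (asymmetric min/max quantization, with the range widened to include zero); it is not the torchao or XNNPACK kernel.

```python
def quantize_per_token_int8(activations):
    """Dynamically quantize each token (row) to signed 8-bit from its runtime min/max."""
    qmin, qmax = -128, 127
    result = []
    for row in activations:
        lo, hi = min(min(row), 0.0), max(max(row), 0.0)  # keep 0.0 exactly representable
        scale = (hi - lo) / (qmax - qmin) or 1.0          # fall back to 1.0 for all-zero rows
        zero_point = max(qmin, min(qmax, round(qmin - lo / scale)))
        codes = [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in row]
        result.append((codes, scale, zero_point))
    return result


def dequantize_per_token(quantized):
    """Map codes back to floats: (code - zero_point) * scale, per row."""
    return [[(c - zp) * scale for c in codes] for codes, scale, zp in quantized]


acts = [[0.5, -1.0, 2.0], [10.0, -10.0, 0.0]]  # two tokens with very different ranges
q = quantize_per_token_int8(acts)
recovered = dequantize_per_token(q)
```

Because each token gets its own scale, a token with a small range is not forced onto the coarse grid needed by a token with a large range.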

Accuracy

We evaluated WikiText perplexity using LM Eval. Below are the results for two different group sizes, with max_seq_length 2048 and limit 1000.

| Model | Baseline (FP32) | Groupwise 4-bit (group size 128) | Groupwise 4-bit (group size 256) |
|---|---|---|---|
| Llama 3 8B | 7.9 | 9.4 | 9.7 |

Please note that LM Eval reports perplexity normalized by word count instead of token count. You may see a different WikiText perplexity from other sources if they implement it differently. More details can be found here.
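To see why the normalization matters: perplexity is exp(total negative log-likelihood / count). Word-level perplexity divides by the number of words, while token-level perplexity divides by the number of tokens. Since a tokenizer usually produces more tokens than words, the same total loss yields a higher word-level perplexity. All numbers in this sketch are hypothetical.

```python
import math


def perplexity(total_nll, count):
    """exp(total negative log-likelihood in nats / number of units)."""
    return math.exp(total_nll / count)


total_nll = 2000.0  # hypothetical summed loss over a document, in nats
num_words = 1000    # hypothetical word count
num_tokens = 1300   # hypothetical token count (tokenizers typically emit more tokens than words)

word_ppl = perplexity(total_nll, num_words)
token_ppl = perplexity(total_nll, num_tokens)
print(f"word-level: {word_ppl:.2f}, token-level: {token_ppl:.2f}")
```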

Performance

Llama 3 8B performance was measured on the Samsung Galaxy S22, S24, and OnePlus 12 devices. The performance measurement is expressed in terms of tokens per second using an adb binary-based approach.

| Device | Groupwise 4-bit (group size 128) | Groupwise 4-bit (group size 256) |
|---|---|---|
| Galaxy S22 | 7.85 tokens/second | 8.4 tokens/second |
| Galaxy S24 | 10.91 tokens/second | 11.21 tokens/second |
| OnePlus 12 | 10.85 tokens/second | 11.02 tokens/second |



Llama 3.1 8B, 4-bit quantized, on Android phone

Please visit this section to try it on a non-CPU backend, such as CoreML, MPS, Qualcomm HTP, or MediaTek.

Instructions

Tested on

  • macOS (M1/M2), Linux.
  • For Llama 3 8B, your device may require at least 32GB RAM. If this is a constraint for you, please try the smaller stories model.

Step 1: Setup

⚠️ Double-check your Python environment: make sure conda activate <VENV> has been run before running any of the bash and Python scripts.

  1. Follow the tutorial to set up ExecuTorch. For installation, run ./install_requirements.sh --pybind xnnpack
  2. Run examples/models/llama/install_requirements.sh to install a few dependencies.

Step 2: Prepare model

Option A: Download and export the Llama 3.2 1B/3B model.

  1. Download consolidated.00.pth, params.json and tokenizer.model from the Llama website or Hugging Face. For chat use cases, download the instruct models.

  2. Export model and generate .pte file.

  • Use original BF16 version, without any quantization.
# No quantization
# Set these paths to point to the downloaded files
LLAMA_CHECKPOINT=path/to/checkpoint.pth
LLAMA_PARAMS=path/to/params.json

python -m examples.models.llama.export_llama \
  --checkpoint "${LLAMA_CHECKPOINT:?}" \
  --params "${LLAMA_PARAMS:?}" \
  -kv \
  --use_sdpa_with_kv_cache \
  -X \
  -d bf16 \
  --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
  --output_name="llama3_2.pte"
  • To use SpinQuant, there are two options:
    • Download the checkpoint directly from the Llama website. The model weights are prequantized and can be exported to a .pte file directly.
    • Follow the SpinQuant instructions for exporting a checkpoint for ExecuTorch, then export that SpinQuant checkpoint as follows.
# SpinQuant
# Set these paths to point to the exported files
LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
LLAMA_PARAMS=path/to/spinquant/params.json

python -m examples.models.llama.export_llama \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   --use_sdpa_with_kv_cache \
   -X \
   --xnnpack-extended-ops \
   --preq_mode 8da4w_output_8da8w \
   --preq_group_size 32 \
   --max_seq_length 2048 \
   --output_name "llama3_2.pte" \
   -kv \
   -d fp32 \
   --preq_embedding_quantize 8,0 \
   --use_spin_quant native \
   --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
  • To use QAT+LoRA, download the checkpoint directly from the Llama website. The model weights are prequantized and can be exported to a .pte file directly with:
# QAT+LoRA
# Set these paths to point to the exported files
LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth
LLAMA_PARAMS=path/to/qlora/params.json

python -m examples.models.llama.export_llama \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   -qat \
   -lora 16 \
   --preq_mode 8da4w_output_8da8w \
   --preq_group_size 32 \
   --preq_embedding_quantize 8,0 \
   --use_sdpa_with_kv_cache \
   -kv \
   -X \
   --xnnpack-extended-ops \
   -d fp32 \
   --max_seq_length 2048 \
   --output_name "llama3_2.pte" \
   --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
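All three export commands above pass the same --metadata JSON, which records the Llama 3 special-token IDs: 128000 is <|begin_of_text|>, while 128009 (<|eot_id|>) and 128001 (<|end_of_text|>) are treated as end-of-sequence tokens. If you script exports, generating the string with json.dumps and quoting it with shlex.quote avoids hand-quoting mistakes. The metadata_arg helper is a hypothetical convenience, not part of export_llama.

```python
import json
import shlex

# Llama 3 special-token IDs used in the --metadata flag above.
BOS_ID = 128000             # <|begin_of_text|>
EOS_IDS = [128009, 128001]  # <|eot_id|>, <|end_of_text|>


def metadata_arg(bos_id, eos_ids):
    """Build a shell-safe --metadata argument string (hypothetical helper)."""
    payload = json.dumps({"get_bos_id": bos_id, "get_eos_ids": eos_ids})
    return "--metadata " + shlex.quote(payload)


arg = metadata_arg(BOS_ID, EOS_IDS)
print(arg)
```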

Option B: Download and export the Llama 3 8B instruct model

You can export and run the original Llama 3 8B instruct model.

  1. Llama 3 pretrained parameters can be downloaded from Meta's official Llama 3 repository.

  2. Export model and generate .pte file

    python -m examples.models.llama.export_llama \
        --checkpoint <consolidated.00.pth> \
        -p <params.json> \
        -kv \
        --use_sdpa_with_kv_cache \
        -X \
        -qmode 8da4w \
        --group_size 128 \
        -d fp32 \
        --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
        --embedding-quantize 4,32 \
        --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
    

    Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with --embedding-quantize 4,32 as shown above to further reduce the model size.

    If you're interested in deploying on non-CPU backends, please refer to the non-CPU backend section.
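A back-of-the-envelope calculation shows why embedding quantization helps for Llama 3 8B, whose 128,256-token vocabulary and 4,096-dimensional embeddings make the lookup table large. The sketch assumes one FP32 scale per group of 32 values and no zero points; the actual storage layout in ExecuTorch may differ.

```python
VOCAB_SIZE = 128_256  # Llama 3 vocabulary size
DIM = 4_096           # Llama 3 8B embedding dimension


def table_mib(bits_per_value, group_size=None, scale_bytes=4):
    """Approximate embedding-table size in MiB; adds one scale per group if grouped (assumption)."""
    n = VOCAB_SIZE * DIM
    total_bytes = n * bits_per_value / 8
    if group_size:
        total_bytes += (n // group_size) * scale_bytes
    return total_bytes / 2**20


fp32 = table_mib(32)
q4 = table_mib(4, group_size=32)
print(f"FP32: {fp32:.0f} MiB, 4-bit/32: {q4:.0f} MiB ({q4 / fp32:.1%} of FP32)")
```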

Step 3: Run on your computer to validate

  1. Build ExecuTorch with optimized CPU performance as follows. Build options are available here.
    cmake -DPYTHON_EXECUTABLE=python \
        -DCMAKE_INSTALL_PREFIX=cmake-out \
        -DEXECUTORCH_ENABLE_LOGGING=1 \
        -DCMAKE_BUILD_TYPE=Release \
        -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
        -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
        -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
        -DEXECUTORCH_BUILD_XNNPACK=ON \
        -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
        -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
        -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
        -Bcmake-out .
    
    cmake --build cmake-out -j16 --target install --config Release
    

Note for Mac users: There's a known linking issue with Xcode 15.1. Refer to Common Issues and Mitigations below for a solution.

  2. Build the llama runner.

    cmake -DPYTHON_EXECUTABLE=python \
        -DCMAKE_INSTALL_PREFIX=cmake-out \
        -DCMAKE_BUILD_TYPE=Release \
        -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
        -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
        -DEXECUTORCH_BUILD_XNNPACK=ON \
        -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
        -Bcmake-out/examples/models/llama \
        examples/models/llama
    
    cmake --build cmake-out/examples/models/llama -j16 --config Release
    
  3. Run the model. Run options are available here.

    cmake-out/examples/models/llama/llama_main --model_path=<model pte file> --tokenizer_path=<tokenizer.model> --prompt=<prompt>
    

To build for the CoreML backend and validate on a Mac, replace -DEXECUTORCH_BUILD_XNNPACK=ON with -DEXECUTORCH_BUILD_COREML=ON.

Step 4: Run benchmark on Android phone

1. Build the llama runner binary for Android

Prerequisite: the Android NDK (tested with r27b), which can be downloaded from here. Note that the Mac package can be unpacked to locate the NDK folder inside it.

1.1 Set Android NDK

export ANDROID_NDK=<path-to-android-ndk>

1.2 Build ExecuTorch and associated libraries for Android

cmake -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
    -DANDROID_ABI=arm64-v8a \
    -DANDROID_PLATFORM=android-23 \
    -DCMAKE_INSTALL_PREFIX=cmake-out-android \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_ENABLE_LOGGING=1 \
    -DPYTHON_EXECUTABLE=python \
    -DEXECUTORCH_BUILD_XNNPACK=ON \
    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
    -Bcmake-out-android .

cmake --build cmake-out-android -j16 --target install --config Release

1.3 Build the llama runner for Android

cmake  -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
    -DANDROID_ABI=arm64-v8a \
    -DANDROID_PLATFORM=android-23 \
    -DCMAKE_INSTALL_PREFIX=cmake-out-android \
    -DCMAKE_BUILD_TYPE=Release \
    -DPYTHON_EXECUTABLE=python \
    -DEXECUTORCH_BUILD_XNNPACK=ON \
    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
    -Bcmake-out-android/examples/models/llama \
    examples/models/llama

cmake --build cmake-out-android/examples/models/llama -j16 --config Release

2. Run on Android via adb shell

Prerequisite: make sure USB debugging is enabled in Developer options on your phone.

2.1 Connect your Android phone

2.2 Upload model, tokenizer and llama runner binary to phone

adb shell mkdir -p /data/local/tmp/llama
adb push <model.pte> /data/local/tmp/llama/
adb push <tokenizer.model> /data/local/tmp/llama/
adb push cmake-out-android/examples/models/llama/llama_main /data/local/tmp/llama/

2.3 Run model

adb shell "cd /data/local/tmp/llama && ./llama_main --model_path <model.pte> --tokenizer_path <tokenizer.model> --prompt \"What is the capital of France?\" --seq_len 120 --warmup=1"

Step 5: Build mobile apps

iOS

Please refer to this tutorial for full instructions on building the iOS LLAMA Demo App. Rename the tokenizer.model file to tokenizer.bin, because the demo app looks for a tokenizer file with the .bin extension.

Android

Please refer to this tutorial for full instructions on building the Android LLAMA Demo App.

Utility tools for Llama enablement

Evaluate model accuracy

Forewarning: Model evaluation without a GPU may take a long time, especially on larger models.

We use LM Eval to evaluate model accuracy.

For base models, use the following example command to calculate perplexity on WikiText.

python -m examples.models.llama.eval_llama \
	-c <checkpoint.pth> \
	-p <params.json> \
	-t <tokenizer.model/bin> \
	-kv \
	-d <checkpoint dtype> \
	--max_seq_len <max sequence length> \
	--limit <number of samples>

For instruct models, use the following example command to calculate the MMLU score.

python -m examples.models.llama.eval_llama \
	-c <checkpoint.pth> \
	-p <params.json> \
	-t <tokenizer.model/bin> \
	-kv \
	-d <checkpoint dtype> \
	--tasks mmlu \
	--num_fewshot 5 \
	--max_seq_len <max sequence length>

See the Llama utils page for more advanced use cases, such as fine-tuning, running smaller models for educational purposes, and quick iteration and verification.

What is coming next?

Quantization

  • Enabling FP16 models to leverage smaller group sizes for 4-bit quantization.
  • Enabling GPTQ for 4-bit groupwise quantization
  • Enabling custom quantization
  • Lower bit quantization

Models

  • Enabling more generative AI models and architectures.

Performance

  • Performance improvement via techniques such as speculative decoding
  • Enabling Llama and other architectures via Vulkan
  • Enabling performant execution of widely used quantization schemes.

Notes

This example tries to reuse the Python code, with minimal modifications to make it compatible with current ExecuTorch:

  1. Since ExecuTorch does not support the complex Tensor data type, customized functions implement rotary embeddings with real-valued arithmetic. Please see the GitHub issue: Support complex data type in ExecuTorch.
  2. No CUDA. ExecuTorch focuses on edge use cases, and CUDA is not available on most edge devices.
  3. No dependencies on fairscale. ColumnParallelLinear, ParallelEmbedding and training are neither needed nor supported in ExecuTorch.
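The rotary-embedding workaround in point 1 relies on a simple identity: multiplying a pair (x0, x1), viewed as the complex number x0 + i·x1, by e^(iθ) gives the same result as a 2-D rotation using cos θ and sin θ. The sketch below checks that identity in pure Python. It rotates adjacent pairs with frequencies base^(-2i/d), as in the reference Llama code, but the function names and the default base of 10000 are illustrative assumptions, not the ExecuTorch implementation.

```python
import cmath
import math


def rope_real(x, pos, base=10000.0):
    """Rotate adjacent pairs (x[2i], x[2i+1]) by pos * base**(-2i/d) using only real math."""
    d = len(x)
    if d % 2:
        raise ValueError("head dimension must be even")
    out = []
    for i in range(d // 2):
        angle = pos * base ** (-2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        x0, x1 = x[2 * i], x[2 * i + 1]
        out.extend([x0 * c - x1 * s, x0 * s + x1 * c])
    return out


def rope_complex(x, pos, base=10000.0):
    """Reference version using complex multiplication, for comparison."""
    d = len(x)
    out = []
    for i in range(d // 2):
        z = complex(x[2 * i], x[2 * i + 1]) * cmath.exp(1j * pos * base ** (-2 * i / d))
        out.extend([z.real, z.imag])
    return out


q = [0.1 * k for k in range(8)]
print(rope_real(q, pos=3))
```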

Common Issues and Mitigations:

  • To clean your build:
git clean -xfd
pip uninstall executorch
./install_requirements.sh --pybind xnnpack

rm -rf cmake-out
  • If you encounter pthread-related issues at link time, add pthread to target_link_libraries in CMakeLists.txt.
  • On Mac, if you hit a linking error in Step 3 with an error message like the following:
0  0x100823648  __assert_rtn + 72
1  0x10074bc5c  ld::Fixup::applyFixup(ld::Atom const*, ld::LayoutLinkedImage const&, unsigned char*) const + 8268
2  0x1007de7d8  ___ZN2ld16LayoutExecutable27writeContentWithoutLinkEditENSt3__14spanIhLm18446744073709551615EEEy_block_invoke + 332
3  0x188cca428  _dispatch_client_callout2 + 20
4  0x188cde850  _dispatch_apply_invoke3 + 336
5  0x188cca3e8  _dispatch_client_callout + 20
6  0x188ccbc68  _dispatch_once_callout + 32
7  0x188cdeeec  _dispatch_apply_invoke_and_wait + 372
8  0x188cdde9c  _dispatch_apply_with_attr_f + 1212
9  0x188cde08c  dispatch_apply + 96
10  0x1007de9e4  void mapReduce<ld::Atom const*, mach_o::Error>(std::__1::span<ld::Atom const*, 18446744073709551615ul>, unsigned long, void (unsigned long, mach_o::Error&, std::__1::span<ld::Atom const*, 18446744073709551615ul>) block_pointer, void (std::__1::span<mach_o::Error, 18446744073709551615ul>) block_pointer) + 336
11  0x1007de594  ld::LayoutExecutable::writeContentWithoutLinkEdit(std::__1::span<unsigned char, 18446744073709551615ul>, unsigned long long) + 1180
12  0x1007e4020  ld::LayoutExecutable::writeToFile(char const*) + 15248
13  0x1007962e8  main + 9424
ld: Assertion failed: (extras.otherInstrOffset != 0 && "Kind::arm64_adrp_ldr missing extra info"), function applyFixup, file Fixup.cpp, line 793.
clang: error: linker command failed with exit code 1 (use -v to see invocation)

This is a known issue with Xcode 15.1. Mitigation: update to the most recent Xcode version, then clean and rebuild.