Merge pull request #1 from oracle-quickstart/ft-codes
finetuning codes
Showing 5 changed files with 331 additions and 0 deletions.
@@ -0,0 +1,86 @@
# ML Fine-Tuning With PyTorch & ROCm on AMD Instinct MI300X

## Overview

- This guide walks through fine-tuning with the LoRA (Low-Rank Adaptation) approach using the Hugging Face libraries, with FSDP (Fully Sharded Data Parallel) for efficient distributed training.
- We used a customized Docker container provided by AMD for this purpose. For those who want to replicate the environment, the list of required packages and the setup instructions below let you build a suitable environment for this fine-tuning exercise on your own.

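Before starting, it is worth confirming that the environment has a ROCm build of PyTorch and can see all of the node's GPUs. This is a minimal sanity check added here for convenience (it is not part of the original setup and assumes PyTorch was installed with ROCm support):

```bash
# torch.version.hip is only populated on ROCm builds (it is None on CUDA builds);
# device_count() should report every MI300X GPU you intend to train on.
python3 -c "import torch; print('HIP:', torch.version.hip, '| GPUs:', torch.cuda.device_count())"
```
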
## Setting up the Virtual Environment
### 1. Clone the Repository (Optional)
- If you have not already cloned this repository, do so now:

```bash
git clone https://github.com/oracle-quickstart/oci-amd-gpu-onboarding.git
cd oci-amd-gpu-onboarding
```
### 2. Create the virtual environment
```bash
python3 -m venv venv_ft
```
### 3. Activate the virtual environment
```bash
source venv_ft/bin/activate
```

### 4. Install required packages
```bash
pip install -r requirements.txt
```
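
Note that `requirements.txt` does not pin versions, so a plain `pip install torch` may resolve to a CUDA wheel. On a ROCm host you typically install PyTorch from the ROCm wheel index first; the index URL below is an assumption for ROCm 6.1 and should be adjusted to match your installed ROCm release:

```bash
# Assumed example for ROCm 6.1 -- swap the index URL for the one matching your ROCm version.
pip install torch --index-url https://download.pytorch.org/whl/rocm6.1
```
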
- You are all set to perform fine-tuning!

## Execute Fine-Tuning

To start the fine-tuning process, run the following command:

```bash
ACCELERATE_USE_FSDP=1 torchrun --nproc_per_node=8 ./lora.py --config llama_3_70b_fsdp_lora.yaml
```

### Parameters:
- `ACCELERATE_USE_FSDP=1`: Enables Fully Sharded Data Parallelism.
- `torchrun --nproc_per_node=8`: Runs the training across 8 processes (GPUs); see the note after this list if your node exposes a different number of GPUs.
- `./lora.py`: The script used for training.
- `--config llama_3_70b_fsdp_lora.yaml`: Configuration file for the LoRA fine-tuning.

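The `--nproc_per_node=8` value assumes a full MI300X node with eight GPUs. As a hedged example (not from the original repository), you could adapt the launch to a node with fewer visible devices, memory permitting, since Llama 3 70B needs substantial aggregate GPU memory:

```bash
# List the GPUs that ROCm exposes to this job (rocm-smi ships with ROCm).
rocm-smi
# Hypothetical 4-GPU launch -- keep --nproc_per_node equal to the visible GPU count.
ACCELERATE_USE_FSDP=1 torchrun --nproc_per_node=4 ./lora.py --config llama_3_70b_fsdp_lora.yaml
```
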
## Configuration File

Before running the training script, you need to modify the `llama_3_70b_fsdp_lora.yaml` configuration file. Here are the key parameters to update:

1. **model_id**: Set this to the directory where your Hugging Face model checkpoints are stored.
```yaml
model_id: "path/to/your/huggingface/model/checkpoints"
```
2. **max_seq_length**: Update this to the desired sequence length for training.
```yaml
max_seq_length: <your_sequence_length>
```
3. **per_device_train_batch_size**: Adjust the batch size based on the `max_seq_length`:
- For `max_seq_length=1024`: set `per_device_train_batch_size: 9`
- For `max_seq_length=2048`: set `per_device_train_batch_size: 4`
- For `max_seq_length=4096`: set `per_device_train_batch_size: 1`

Example:
```yaml
per_device_train_batch_size: 9 # for max_seq_length=1024
```

## Example Configuration

Here's an example of how your `llama_3_70b_fsdp_lora.yaml` might look after editing:

```yaml
model_id: "path/to/your/huggingface/model/checkpoints"
max_seq_length: 1024
per_device_train_batch_size: 9
```

## Running the Fine-Tuning

After making the necessary changes to the configuration file, you can execute the fine-tuning command provided above.

## Monitoring and Logging

Monitor the training logs for any errors or warnings, and adjust the configuration as needed based on performance and resource utilization. We used MLflow to track the experiment; here is the link to our [experiment run](http://mlflow-benchmarking.corrino-oci.com:5000/#/experiments/63/runs/c836a0297f41440aab97fa55c714d7e4).
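
The training script (`lora.py`) reports to MLflow, but it ships with an empty tracking URI (`mlflow.set_tracking_uri("")`), so point it at your own tracking server before launching. A minimal sketch, assuming you want a locally hosted server on port 5000 (host and port are placeholders, not values from this repository):

```bash
# Start a local MLflow tracking server (the `mlflow server` CLI comes with the
# mlflow package listed in requirements.txt), then paste its URL, e.g.
# http://<your-host>:5000, into the mlflow.set_tracking_uri(...) call in lora.py.
mlflow server --host 0.0.0.0 --port 5000
```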
@@ -0,0 +1,28 @@
# script parameters
model_id: "/data/models/Meta-Llama-3-70B" # Hugging Face model id
dataset_path: "." # path to dataset
max_seq_length: 2048 # max sequence length for model and packing of the dataset
# training parameters
output_dir: "./llama-3-70b-hf-no-robot" # Temporary output directory for model checkpoints
report_to: "mlflow" # report metrics to MLflow
learning_rate: 0.00001 # learning rate (1e-5)
lr_scheduler_type: "constant" # learning rate scheduler
num_train_epochs: 1 # number of training epochs
per_device_train_batch_size: 2 # batch size per device during training
per_device_eval_batch_size: 1 # batch size for evaluation
gradient_accumulation_steps: 1 # number of steps before performing a backward/update pass
optim: adamw_torch # use torch adamw optimizer
logging_steps: 1 # log every step
save_strategy: epoch # save checkpoint every epoch
evaluation_strategy: epoch # evaluate every epoch
max_grad_norm: 1 # max gradient norm
warmup_ratio: 0.0 # warmup ratio
bf16: true # use bfloat16 precision
tf32: false # do not use tf32 precision
gradient_checkpointing: false # use gradient checkpointing to save memory
# FSDP parameters: https://huggingface.co/docs/transformers/main/en/fsdp
fsdp: "full_shard auto_wrap" # remove offload if enough GPU memory
fsdp_config:
  backward_prefetch: "backward_pre"
  forward_prefetch: "false"
  use_orig_params: "false"
@@ -0,0 +1,206 @@
import logging
import time
import psutil
from dataclasses import dataclass, field
import os
import random
import torch
from datasets import load_dataset
from trl.commands.cli_utils import TrlParser
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    set_seed,
)
from trl import setup_chat_format
from peft import LoraConfig
from trl import SFTTrainer
import mlflow
from accelerate import Accelerator

# Anthropic/Vicuna like template without the need for special tokens
LLAMA_3_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'system' %}"
    "{{ message['content'] }}"
    "{% elif message['role'] == 'user' %}"
    "{{ '\n\nHuman: ' + message['content'] + eos_token }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '\n\nAssistant: ' + message['content'] + eos_token }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '\n\nAssistant: ' }}"
    "{% endif %}"
)


@dataclass
class ScriptArguments:
    dataset_path: str = field(
        default=None,
        metadata={"help": "Path to the dataset"},
    )
    model_id: str = field(
        default=None, metadata={"help": "Model ID to use for SFT training"}
    )
    max_seq_length: int = field(
        default=2048, metadata={"help": "The maximum sequence length for SFT Trainer"}
    )


def b2mb(x):
    # Convert a byte count to whole mebibytes (MiB).
    return int(x / 2**20)


class TorchTracemalloc:
    # Context manager that records GPU and CPU memory deltas (in MiB) around a block.
    def __enter__(self):
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()
        self.cpu_begin = self.process.memory_info().rss
        return self

    def __exit__(self, *exc):
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)
        self.cpu_end = self.process.memory_info().rss
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)


def training_function(script_args, training_args):
    accelerator = Accelerator()

    if accelerator.is_main_process:
        mlflow.set_tracking_uri("")  # add your mlflow endpoint here, e.g. "http://<host>:5000"
        experiment_name = "Mi300x_llama3_70B_scrolls_govt_report_experimemt_FSDP_1"
        mlflow.set_experiment(experiment_name)
        mlflow.start_run()
        mlflow.log_params(
            {
                "model_id": script_args.model_id,
                "dataset": "tau/scrolls",
                "subset": "gov_report",
                "max_seq_length": script_args.max_seq_length,
            }
        )

    train_dataset = load_dataset("tau/scrolls", "gov_report", split="train")
    test_dataset = load_dataset("tau/scrolls", "gov_report", split="test")

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.chat_template = LLAMA_3_CHAT_TEMPLATE

    def template_dataset(examples):
        return {"text": examples["input"]}

    train_dataset = train_dataset.map(template_dataset, remove_columns=["input"])
    test_dataset = test_dataset.map(template_dataset, remove_columns=["input"])

    with training_args.main_process_first(
        desc="Log a few random samples from the processed training set"
    ):
        for index in random.sample(range(len(train_dataset)), 2):
            print(train_dataset[index]["text"])

    torch_dtype = torch.bfloat16
    quant_storage_dtype = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_id,
        attn_implementation="flash_attention_2",
        torch_dtype=quant_storage_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
    )

    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    peft_config = LoraConfig(
        lora_alpha=32,
        lora_dropout=0.05,
        r=8,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        dataset_text_field="text",
        eval_dataset=test_dataset,
        peft_config=peft_config,
        max_seq_length=script_args.max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        dataset_kwargs={
            "add_special_tokens": False,
            "append_concat_token": False,
        },
    )
    if trainer.accelerator.is_main_process:
        trainer.model.print_trainable_parameters()

    train_start_time = time.time()

    for epoch in range(int(training_args.num_train_epochs)):
        with TorchTracemalloc() as tracemalloc:
            train_output = trainer.train()

        train_loss = train_output.training_loss
        train_runtime = train_output.metrics["train_runtime"]
        train_samples_per_second = train_output.metrics["train_samples_per_second"]
        train_steps_per_second = train_output.metrics["train_steps_per_second"]

        eval_output = trainer.evaluate()
        eval_loss = eval_output["eval_loss"]
        eval_perplexity = eval_output["eval_perplexity"]

        if accelerator.is_main_process:
            mlflow.log_metrics(
                {
                    "epoch": epoch + 1,
                    "train_loss": train_loss,
                    "train_runtime": train_runtime,
                    "train_samples_per_second": train_samples_per_second,
                    "train_steps_per_second": train_steps_per_second,
                    "eval_loss": eval_loss,
                    "eval_perplexity": eval_perplexity,
                    "gpu_memory_used": tracemalloc.used,
                    "gpu_peak_memory": tracemalloc.peaked,
                    "cpu_memory_used": tracemalloc.cpu_used,
                },
                step=epoch + 1,
            )

    total_train_runtime = time.time() - train_start_time

    if accelerator.is_main_process:
        mlflow.log_metrics(
            {
                "total_train_runtime": total_train_runtime,
                "final_train_loss": train_loss,
                "final_eval_loss": eval_loss,
                "final_eval_perplexity": eval_perplexity,
            }
        )
        mlflow.end_run()

    if trainer.is_fsdp_enabled:
        # Gather a full (unsharded) state dict when checkpointing under FSDP.
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")


if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, TrainingArguments))
    script_args, training_args = parser.parse_args_and_config()

    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
    set_seed(training_args.seed)

    training_function(script_args, training_args)
@@ -0,0 +1,8 @@
psutil
torch
datasets
transformers
trl
peft
mlflow
accelerate
@@ -0,0 +1,3 @@
# Example: launch the LoRA fine-tuning with FSDP across 8 GPUs
ACCELERATE_USE_FSDP=1 torchrun --nproc_per_node=8 ./lora.py --config llama_3_70b_fsdp_lora.yaml