Merge pull request #1 from oracle-quickstart/ft-codes
 finetuning codes
amgowda-oci authored Jan 21, 2025
2 parents 69312e1 + f5c007a commit 739b07d
Showing 5 changed files with 331 additions and 0 deletions.
86 changes: 86 additions & 0 deletions MI300X/single-node-finetuning/README.md
# ML Fine-Tuning with PyTorch & ROCm on AMD Instinct MI300X

## Overview

- This guide covers fine-tuning with the LoRA (Low-Rank Adaptation) approach and the Hugging Face libraries, using FSDP (Fully Sharded Data Parallel) for efficient distributed training.
- We ran the exercise in a customized Docker container provided by AMD for this purpose. To replicate the environment independently, the package list and setup instructions below are all you need; a container-launch sketch follows this list.
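
For reference, here is a minimal sketch of launching a ROCm PyTorch container. The `rocm/pytorch:latest` image tag and the mount paths are assumptions; substitute the AMD-provided image if you have one:

```bash
# Launch an interactive ROCm PyTorch container with access to all GPUs;
# /dev/kfd and /dev/dri expose the AMD compute and render devices
docker run -it --rm \
  --device=/dev/kfd --device=/dev/dri \
  --security-opt seccomp=unconfined \
  --group-add video \
  --ipc=host --shm-size 64G \
  -v "$PWD":/workspace -w /workspace \
  rocm/pytorch:latest
```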

## Setting up the Virtual Environment
### 1. Clone the Repository (Optional)
- If you have not already cloned this repo, do so now:

```bash
git clone https://github.com/oracle-quickstart/oci-amd-gpu-onboarding.git
cd oci-amd-gpu-onboarding
```
### 2. Create the virtual environment
```bash
python3 -m venv venv_ft
```
### 3. Activate the virtual environment
```bash
source venv_ft/bin/activate
```

### 4. Install required packages
```bash
pip install -r requirements.txt
```
- You are all set to perform fine-tuning!
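
Optionally, confirm that PyTorch detects the GPUs before launching; ROCm builds of PyTorch expose AMD GPUs through the `torch.cuda` API:

```bash
python -c "import torch; print(torch.cuda.is_available(), torch.cuda.device_count())"
# expected on a single MI300X node: True 8
```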

## Execute Fine-Tuning

To start the fine-tuning process, run the following command:

```bash
ACCELERATE_USE_FSDP=1 torchrun --nproc_per_node=8 ./lora.py --config llama_3_70b_fsdp_lora.yaml
```

### Parameters:
- `ACCELERATE_USE_FSDP=1`: Enables Fully Sharded Data Parallelism.
- `torchrun --nproc_per_node=8`: Runs the training across 8 processes (GPUs).
- `./lora.py`: The script used for training.
- `--config llama_3_70b_fsdp_lora.yaml`: Configuration file for the LoRA fine-tuning.
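
To train on a subset of the GPUs, lower `--nproc_per_node` and restrict device visibility to match. A sketch using ROCm's `HIP_VISIBLE_DEVICES` variable; with fewer GPUs you may also need to reduce the batch size or sequence length:

```bash
# Run FSDP fine-tuning on the first four GPUs only
HIP_VISIBLE_DEVICES=0,1,2,3 ACCELERATE_USE_FSDP=1 torchrun --nproc_per_node=4 ./lora.py --config llama_3_70b_fsdp_lora.yaml
```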

## Configuration File

Before running the training script, you need to modify the `llama_3_70b_fsdp_lora.yaml` configuration file. Here are the key parameters to update:

1. **model_id**: Set this to the directory where your Hugging Face model checkpoints are stored.
```yaml
model_id: "path/to/your/huggingface/model/checkpoints"
```
2. **max_seq_length**: Update this to the desired sequence length for training.
```yaml
max_seq_length: <your_sequence_length>
```
3. **per_device_train_batch_size**: Adjust the batch size based on the `max_seq_length`:
- For `max_seq_length=1024`: set `per_device_train_batch_size: 9`
- For `max_seq_length=2048`: set `per_device_train_batch_size: 4`
- For `max_seq_length=4096`: set `per_device_train_batch_size: 1`

Example:
```yaml
per_device_train_batch_size: 9 # for max_seq_length=1024
```
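
If you lower `per_device_train_batch_size` for longer sequences but want to keep the effective batch size (in samples) comparable, `gradient_accumulation_steps` from the same config file can compensate. A sketch:

```yaml
max_seq_length: 4096
per_device_train_batch_size: 1
gradient_accumulation_steps: 9  # with batch 1, 9 accumulation steps keep ~9 samples per optimizer step
```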

## Example Configuration

Here’s an example of how your `llama_3_70b_fsdp_lora.yaml` might look after editing:

```yaml
model_id: "path/to/your/huggingface/model/checkpoints"
max_seq_length: 1024
per_device_train_batch_size: 9
```

## Running the Fine-Tuning

After making the necessary changes to the configuration file, you can execute the fine-tuning command provided above.

## Monitoring and Logging

Monitor the training logs for errors or warnings, and adjust the configuration as needed based on performance and resource utilization. We used MLflow to track the experiment; here is the link to our [experiment](http://mlflow-benchmarking.corrino-oci.com:5000/#/experiments/63/runs/c836a0297f41440aab97fa55c714d7e4).
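
You can also pull the logged metrics programmatically with the MLflow client. A minimal sketch, assuming a recent MLflow version and your own tracking endpoint (the URI below is a placeholder):

```python
import mlflow

# Point at your MLflow tracking server (placeholder URI -- substitute your own)
mlflow.set_tracking_uri("http://your-mlflow-host:5000")

# Returns a pandas DataFrame with one row per run and "metrics.*" columns
runs = mlflow.search_runs(
    experiment_names=["Mi300x_llama3_70B_scrolls_govt_report_experimemt_FSDP_1"]
)
print(runs[["run_id", "metrics.train_loss", "metrics.eval_loss", "metrics.eval_perplexity"]])
```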

28 changes: 28 additions & 0 deletions MI300X/single-node-finetuning/llama_3_70b_fsdp_lora.yaml
# script parameters
model_id: "/data/models/Meta-Llama-3-70B" # Hugging Face model id or local checkpoint path
dataset_path: "."                         # path to dataset
max_seq_length: 2048                      # max sequence length for model and packing of the dataset
# training parameters
output_dir: "./llama-3-70b-hf-no-robot"   # temporary output directory for model checkpoints
report_to: "mlflow"                       # report metrics to mlflow
learning_rate: 0.00001                    # learning rate (1e-5)
lr_scheduler_type: "constant"             # learning rate scheduler
num_train_epochs: 1                       # number of training epochs
per_device_train_batch_size: 2            # batch size per device during training
per_device_eval_batch_size: 1             # batch size for evaluation
gradient_accumulation_steps: 1            # number of steps before performing a backward/update pass
optim: adamw_torch                        # use torch adamw optimizer
logging_steps: 1                          # log every step
save_strategy: epoch                      # save checkpoint every epoch
evaluation_strategy: epoch                # evaluate every epoch
max_grad_norm: 1                          # max gradient norm
warmup_ratio: 0.0                         # warmup ratio
bf16: true                                # use bfloat16 precision
tf32: false                               # do not use tf32 precision
gradient_checkpointing: false             # use gradient checkpointing to save memory
# FSDP parameters: https://huggingface.co/docs/transformers/main/en/fsdp
fsdp: "full_shard auto_wrap"              # remove offload if enough GPU memory
fsdp_config:
  backward_prefetch: "backward_pre"
  forward_prefetch: false
  use_orig_params: false
206 changes: 206 additions & 0 deletions MI300X/single-node-finetuning/lora.py
import math
import random
import time
from dataclasses import dataclass, field

import mlflow
import psutil
import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    set_seed,
)
from trl import SFTTrainer
from trl.commands.cli_utils import TrlParser

# Anthropic/Vicuna-like template without the need for special tokens
LLAMA_3_CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"{{ message['content'] }}"
"{% elif message['role'] == 'user' %}"
"{{ '\n\nHuman: ' + message['content'] + eos_token }}"
"{% elif message['role'] == 'assistant' %}"
"{{ '\n\nAssistant: ' + message['content'] + eos_token }}"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '\n\nAssistant: ' }}"
"{% endif %}"
)
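# For illustration, a system + user + assistant exchange renders roughly as:
#   "<system text>\n\nHuman: <user text><eos>\n\nAssistant: <reply><eos>"
# (the exact string depends on the tokenizer's eos_token)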


@dataclass
class ScriptArguments:
dataset_path: str = field(
default=None,
metadata={"help": "Path to the dataset"},
)
model_id: str = field(
default=None, metadata={"help": "Model ID to use for SFT training"}
)
max_seq_length: int = field(
default=2048, metadata={"help": "The maximum sequence length for SFT Trainer"}
)


def b2mb(x):
    """Convert a byte count to mebibytes."""
    return int(x / 2**20)


class TorchTracemalloc:
    """Context manager that tracks GPU and CPU memory usage of a code block (in MiB)."""

    def __enter__(self):
        torch.cuda.reset_peak_memory_stats()  # so self.peak reflects this block only
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()
        self.cpu_begin = self.process.memory_info().rss
        return self

    def __exit__(self, *exc):
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)
        self.cpu_end = self.process.memory_info().rss
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)


def training_function(script_args, training_args):
accelerator = Accelerator()

if accelerator.is_main_process:
mlflow.set_tracking_uri("") # add your mlflow endpoint
experiment_name = "Mi300x_llama3_70B_scrolls_govt_report_experimemt_FSDP_1"
mlflow.set_experiment(experiment_name)
mlflow.start_run()
mlflow.log_params(
{
"model_id": script_args.model_id,
"dataset": "tau/scrolls",
"subset": "gov_report",
"max_seq_length": script_args.max_seq_length,
}
)

train_dataset = load_dataset("tau/scrolls", "gov_report", split="train")
test_dataset = load_dataset("tau/scrolls", "gov_report", split="test")

tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = LLAMA_3_CHAT_TEMPLATE

def template_dataset(examples):
return {"text": examples["input"]}

train_dataset = train_dataset.map(template_dataset, remove_columns=["input"])
test_dataset = test_dataset.map(template_dataset, remove_columns=["input"])

with training_args.main_process_first(
desc="Log a few random samples from the processed training set"
):
for index in random.sample(range(len(train_dataset)), 2):
print(train_dataset[index]["text"])

torch_dtype = torch.bfloat16
quant_storage_dtype = torch.bfloat16

model = AutoModelForCausalLM.from_pretrained(
script_args.model_id,
attn_implementation="flash_attention_2",
torch_dtype=quant_storage_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
)

if training_args.gradient_checkpointing:
model.gradient_checkpointing_enable()

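    # LoRA hyperparameters: r is the adapter rank, lora_alpha scales the adapter
    # update, and target_modules="all-linear" attaches adapters to every linear layer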
peft_config = LoraConfig(
lora_alpha=32,
lora_dropout=0.05,
r=8,
bias="none",
target_modules="all-linear",
task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
dataset_text_field="text",
eval_dataset=test_dataset,
peft_config=peft_config,
max_seq_length=script_args.max_seq_length,
tokenizer=tokenizer,
packing=True,
dataset_kwargs={
"add_special_tokens": False,
"append_concat_token": False,
},
)
if trainer.accelerator.is_main_process:
trainer.model.print_trainable_parameters()

train_start_time = time.time()

    # NB: trainer.train() itself runs num_train_epochs epochs internally, so this
    # outer loop is only correct with num_train_epochs: 1 (as in the provided config)
    for epoch in range(int(training_args.num_train_epochs)):
with TorchTracemalloc() as tracemalloc:
train_output = trainer.train()

train_loss = train_output.training_loss
train_runtime = train_output.metrics["train_runtime"]
train_samples_per_second = train_output.metrics["train_samples_per_second"]
train_steps_per_second = train_output.metrics["train_steps_per_second"]

        eval_output = trainer.evaluate()
        eval_loss = eval_output["eval_loss"]
        # trainer.evaluate() does not report perplexity by default; derive it from the loss
        eval_perplexity = math.exp(eval_loss)

if accelerator.is_main_process:
mlflow.log_metrics(
{
"epoch": epoch + 1,
"train_loss": train_loss,
"train_runtime": train_runtime,
"train_samples_per_second": train_samples_per_second,
"train_steps_per_second": train_steps_per_second,
"eval_loss": eval_loss,
"eval_perplexity": eval_perplexity,
"gpu_memory_used": tracemalloc.used,
"gpu_peak_memory": tracemalloc.peaked,
"cpu_memory_used": tracemalloc.cpu_used,
},
step=epoch + 1,
)

total_train_runtime = time.time() - train_start_time

if accelerator.is_main_process:
mlflow.log_metrics(
{
"total_train_runtime": total_train_runtime,
"final_train_loss": train_loss,
"final_eval_loss": eval_loss,
"final_eval_perplexity": eval_perplexity,
}
)
mlflow.end_run()

    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model()  # persist the final checkpoint to output_dir (consolidated under FSDP)


if __name__ == "__main__":
parser = TrlParser((ScriptArguments, TrainingArguments))
script_args, training_args = parser.parse_args_and_config()

if training_args.gradient_checkpointing:
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
set_seed(training_args.seed)

training_function(script_args, training_args)
8 changes: 8 additions & 0 deletions MI300X/single-node-finetuning/requirements.txt
psutil
torch
datasets
transformers
trl
peft
mlflow
accelerate
3 changes: 3 additions & 0 deletions MI300X/single-node-finetuning/run.sh
#!/bin/bash
# Example: single-node FSDP LoRA fine-tuning across 8 GPUs
ACCELERATE_USE_FSDP=1 torchrun --nproc_per_node=8 ./lora.py --config llama_3_70b_fsdp_lora.yaml
