Commit: WIP
tengomucho committed Feb 5, 2025
1 parent aa390dc commit a7f9152
Showing 2 changed files with 31 additions and 9 deletions.
notebooks/text-classification/fine_tune_bert.ipynb (7 changes: 3 additions & 4 deletions)

@@ -268,13 +268,12 @@
"metadata": {},
"outputs": [],
"source": [
"!torchrun --nproc_per_node=2 train.py \\\n",
"!XLA_USE_BF16=1 torchrun --nproc_per_node=2 train.py \\\n",
" --model_id bert-base-uncased \\\n",
" --dataset_path lm_dataset \\\n",
" --lr 5e-5 \\\n",
" --per_device_train_batch_size 16 \\\n",
" --bf16 True \\\n",
" --epochs 3"
" --per_device_train_batch_size 8 \\\n",
" --bf16 True"
]
},
{
notebooks/text-classification/scripts/train.py (33 changes: 28 additions & 5 deletions)

@@ -1,6 +1,7 @@
 import argparse
 import logging
 import os
+from dataclasses import dataclass, field
 
 import evaluate
 import numpy as np
@@ -34,9 +35,10 @@ def parse_args():
     )
     # add training hyperparameters for epochs, batch size, learning rate, and seed
     parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for.")
+    parser.add_argument("--max_steps", type=int, default=-1, help="Number of steps to train for.")
     parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size to use for training.")
     parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Batch size to use for testing.")
-    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate to use for training.")
+    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate to use for training.")
     parser.add_argument("--seed", type=int, default=42, help="Seed to use for training.")
     parser.add_argument(
         "--bf16",
@@ -93,18 +95,19 @@ def training_function(args):
     training_args = TrainingArguments(
         overwrite_output_dir=True,
         output_dir=output_dir,
+        do_train=True,
+        do_eval=False,
         per_device_train_batch_size=args.per_device_train_batch_size,
         per_device_eval_batch_size=args.per_device_eval_batch_size,
         bf16=args.bf16,  # Use BF16 if available
-        learning_rate=args.lr,
+        learning_rate=args.learning_rate,
         num_train_epochs=args.epochs,
+        max_steps=args.max_steps,
         # logging & evaluation strategies
         logging_dir=f"{output_dir}/logs",
         logging_strategy="steps",
         logging_steps=500,
-        evaluation_strategy="epoch",
-        save_strategy="epoch",
-        save_total_limit=2,
+        save_strategy="steps",
         # push to hub parameters
         report_to="tensorboard",
         push_to_hub=True if args.repository_id else False,
@@ -113,6 +116,10 @@
         hub_token=args.hf_token,
     )
 
+    # from rich import print
+    # print(training_args)
+    # breakpoint()
+
     # Create Trainer instance
     trainer = Trainer(
         model=model,
@@ -137,7 +144,23 @@ def training_function(args):
         trainer.push_to_hub()
 
 
+@dataclass
+class ModelArguments:
+    model_id: str = field(default="bert-large-uncased",
+                          metadata={"help": "Model id to use for training."})
+    dataset_path: str = field(default="dataset",
+                              metadata={"help": "Path to the already processed dataset."})
+    epochs: int = field(default=3,
+                        metadata={"help": "Number of epochs to train for."})
+
+
 def main():
+    # from optimum.neuron import NeuronHfArgumentParser as HfArgumentParser
+    # parser = HfArgumentParser((ModelArguments, TrainingArguments))
+    # model_args, training_args = parser.parse_args_into_dataclasses()
+    # print(training_args)
+    # breakpoint()
+
     args, _ = parse_args()
     training_function(args)
 
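For orientation, the commented-out lines in main() point toward replacing the hand-rolled argparse flags with dataclass-based argument parsing. The snippet below is a minimal, hypothetical sketch of that pattern using the stock transformers HfArgumentParser (the comment in the diff suggests NeuronHfArgumentParser from optimum.neuron as a drop-in); it is not part of the commit, and the ModelArguments fields simply mirror the dataclass added above.

from dataclasses import dataclass, field

from transformers import HfArgumentParser, TrainingArguments


@dataclass
class ModelArguments:
    # Non-Trainer options, mirroring the dataclass added in this commit.
    model_id: str = field(default="bert-large-uncased",
                          metadata={"help": "Model id to use for training."})
    dataset_path: str = field(default="dataset",
                              metadata={"help": "Path to the already processed dataset."})


def main():
    # One parser covers both dataclasses, so CLI flags such as
    # --model_id, --output_dir, --max_steps, or --bf16 are routed
    # automatically into the matching dataclass instance.
    parser = HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()
    print(model_args)
    print(training_args)


if __name__ == "__main__":
    main()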
