Hi all, I am trying to fine-tune models on extremely long contexts.
I've tested the training setup below, and I managed to fine-tune:

- Llama-3.2-1B with a max_seq_length of 128 * 1024 tokens
- Qwen2.5-Coder-1.5B-Instruct-bnb-4bit / Qwen2.5-Coder-0.5B-Instruct-bnb-4bit with a max_seq_length of 64 * 1024 tokens
I would really like to reach a context length of 128K tokens for Qwen as well; however, I get an OOM (out-of-memory) error, even with the smallest 0.5B model. Is there anything else I can do to optimize training over such long contexts?
Furthermore, why do I get no memory error when fine-tuning Llama-3.2-1B, which has double the parameters of the 0.5B Qwen model?
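As a sanity check on the parameter-count intuition, here is a rough sketch of the tensors that scale with sequence length rather than with parameter count (logits over the vocabulary, per-layer hidden states, K/V projections). The config values below are the published ones for Llama-3.2-1B and Qwen2.5-0.5B and are assumptions on my part, and the estimate ignores framework-level optimizations such as fused losses and checkpointing granularity, so treat it as indicative only:

```python
# Rough, assumption-laden sketch: memory that grows with sequence length,
# independent of total parameter count. Config values are the published ones
# for Llama-3.2-1B and Qwen2.5-0.5B and may differ from the exact checkpoints.

def seq_memory_gib(seq_len, hidden, layers, kv_heads, head_dim, vocab, bytes_per=2):
    logits = seq_len * vocab * bytes_per                                   # lm_head output
    kv = seq_len * layers * 2 * kv_heads * head_dim * bytes_per            # K and V projections
    hidden_states = seq_len * layers * hidden * bytes_per                  # ~1 saved activation/layer
    return (logits + kv + hidden_states) / 1024**3

seq = 128 * 1024
# Llama-3.2-1B: hidden 2048, 16 layers, 8 KV heads of dim 64, vocab 128256
print("Llama-3.2-1B :", seq_memory_gib(seq, 2048, 16, 8, 64, 128256))
# Qwen2.5-0.5B: hidden 896, 24 layers, 2 KV heads of dim 64, vocab 151936
print("Qwen2.5-0.5B :", seq_memory_gib(seq, 896, 24, 2, 64, 151936))
```

Under these assumptions the vocabulary (logits) term dominates at 128K tokens, and Qwen's larger vocabulary makes that term bigger despite the smaller parameter count, so parameter count alone is a poor predictor of peak memory here.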
My codebase is:
```python
import torch
from datasets import load_from_disk
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel, is_bfloat16_supported

# MODEL_NAME, DATASET_PATH, max_seq_length, load_in_4bit and collator are defined elsewhere

# Load the model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_NAME,  # "unsloth/Llama-3.2-1B-Instruct" or "Qwen2.5-Coder-1.5B-Instruct-bnb-4bit"
    dtype = torch.bfloat16,
    load_in_4bit = load_in_4bit,
    max_seq_length = max_seq_length,
)
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0,  # Supports any, but = 0 is optimized
    bias = "none",     # Supports any, but = "none" is optimized
    # "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None,  # And LoftQ
)

# Load the dataset (fingerprint logging and shuffling omitted here);
# load_from_disk returns the saved DatasetDict with its "train"/"test" splits
dataset = load_from_disk(DATASET_PATH)

# Define the training config
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset["train"],
    eval_dataset = dataset["test"],
    dataset_text_field = "text",
    max_seq_length = tokenizer.model_max_length,
    data_collator = collator,
    dataset_num_proc = 1,
    packing = False,  # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 1,
        per_device_eval_batch_size = 1,
        gradient_accumulation_steps = 8,
        eval_accumulation_steps = 1,
        warmup_steps = 5,
        eval_steps = 4,
        num_train_epochs = 1,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        fp16_full_eval = True,
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",  # Use this for WandB etc.
        eval_strategy = "steps",
    ),
)
trainer.train()
```
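To see how close each run is to the limit, it can also help to log peak VRAM around the `trainer.train()` call above. A minimal sketch using standard `torch.cuda` utilities (not an Unsloth-specific API), meant to replace the bare call:

```python
import torch

torch.cuda.reset_peak_memory_stats()
trainer.train()
print(f"Peak allocated VRAM: {torch.cuda.max_memory_allocated() / 1024**3:.1f} GiB")
print(f"Peak reserved VRAM:  {torch.cuda.max_memory_reserved() / 1024**3:.1f} GiB")
```

Comparing these numbers between the Llama and Qwen runs at the same max_seq_length should show which model is actually closer to the OOM threshold.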
Llama 3.2 1B can reach 128K tokens, while Qwen 0.5B can only reach 64K.
I'm aware that this may also depend on the internal settings of the models; I was wondering if there is any particular setting that could help me further optimize training with super long examples (~128K tokens).
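For reference, those internal settings can be read directly from the Hugging Face configs to compare the two models. A minimal sketch; the hub repo IDs below are my assumption of the exact checkpoints in use:

```python
from transformers import AutoConfig

# Print the config fields most relevant to long-context memory use
for name in ["unsloth/Llama-3.2-1B-Instruct",
             "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit"]:
    cfg = AutoConfig.from_pretrained(name)
    print(name, cfg.num_hidden_layers, cfg.hidden_size,
          cfg.num_key_value_heads, cfg.vocab_size, cfg.max_position_embeddings)
```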
Despite trying various approaches, I’ve been unable to exceed these limits on a single GPU. Scaling up to multiple GPUs would address the issue, but unfortunately, Unsloth does not yet support multi-GPU training. Additionally, switching to other frameworks isn’t an option since handling such large sequence lengths requires tensor parallelism to fit individual layers into GPU memory—a process that is complex to configure.