Fix orpo/dpo trainer #1286

Open · dame-cell wants to merge 2 commits into main
Conversation

@dame-cell dame-cell commented Nov 13, 2024

This draft is a temporary fix for issue #1285.

Since trl 0.12.0, the trainers take processing_class instead of tokenizer, so the trainer setup needs to change to:

from transformers import TrainingArguments
from trl import DPOTrainer, DPOConfig
from unsloth import is_bfloat16_supported

# the newest version of trl now uses processing_class instead of tokenizer

dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=DPOConfig(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,
        num_train_epochs=3,
        learning_rate=5e-6,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.0,
        lr_scheduler_type="linear",
        seed=42,
        output_dir="outputs",
        report_to="none",  # Use this for WandB etc.
    ),
    beta=0.1,
    train_dataset=raw_datasets["train"],
    #tokenizer=tokenizer,
    processing_class=tokenizer, 
    max_length=1024,
    max_prompt_length=512,
)
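
The ORPO notebook needs the same rename. Below is a minimal sketch, assuming trl 0.12.0's ORPOTrainer / ORPOConfig API; the hyperparameter values are illustrative, not taken from the notebook:

from trl import ORPOTrainer, ORPOConfig
from unsloth import is_bfloat16_supported

orpo_trainer = ORPOTrainer(
    model=model,
    args=ORPOConfig(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        beta=0.1,
        max_length=1024,
        max_prompt_length=512,
        num_train_epochs=1,
        learning_rate=8e-6,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=42,
        output_dir="outputs",
        report_to="none",
    ),
    train_dataset=raw_datasets["train"],
    processing_class=tokenizer,  # was tokenizer=tokenizer before trl 0.12.0
)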

And for some reason the tokenizer returned by Unsloth's FastLanguageModel.from_pretrained does not work well as the processing_class, so we need to load the original tokenizer:

## For the DPO Colab notebook
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("unsloth/zephyr-sft-bnb-4bit")

## For the ORPO Colab notebook
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-bnb-4bit")

instead of

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
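
Putting the two changes together, the workaround looks like this. A minimal sketch: model_name, max_seq_length, dtype, and load_in_4bit are the notebook's existing variables, and the tokenizer returned by Unsloth is simply discarded:

from transformers import AutoTokenizer
from unsloth import FastLanguageModel

# Still load the model through Unsloth, but ignore the tokenizer it returns
model, _ = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

# Load the plain Hugging Face tokenizer and pass it to the trainer as processing_class
tokenizer = AutoTokenizer.from_pretrained(model_name)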

@danielhanchen

@dame-cell dame-cell marked this pull request as ready for review November 13, 2024 14:15