Replies: 2 comments 3 replies
-
Interesting, thanks for the comment. The reference model's weights shouldn't actually get updated, though, since only the policy model's parameters are passed to the optimizer:

```python
optimizer = torch.optim.AdamW(policy_model.parameters(), lr=5e-6, weight_decay=0.01)
...
loss.backward()   # Calculate loss gradients
optimizer.step()  # Update model weights using loss gradients
```

But computing gradients for the reference model is still bad from an efficiency perspective. I believe that the following change should improve it.

Before:

```python
def compute_dpo_loss_batch(batch, policy_model, reference_model, beta):
    """Compute the DPO loss on an input batch"""

    # where policy_model(batch["chosen"]) are the logits
    policy_chosen_log_probas = compute_logprobs(
        logits=policy_model(batch["chosen"]),
        labels=batch["chosen"],
        selection_mask=batch["chosen_mask"]
    )
    policy_rejected_log_probas = compute_logprobs(
        logits=policy_model(batch["rejected"]),
        labels=batch["rejected"],
        selection_mask=batch["rejected_mask"]
    )
    ref_chosen_log_probas = compute_logprobs(
        logits=reference_model(batch["chosen"]),
        labels=batch["chosen"],
        selection_mask=batch["chosen_mask"]
    )
    ref_rejected_log_probas = compute_logprobs(
        logits=reference_model(batch["rejected"]),
        labels=batch["rejected"],
        selection_mask=batch["rejected_mask"]
    )
```

After:

```python
def compute_dpo_loss_batch(batch, policy_model, reference_model, beta):
    """Compute the DPO loss on an input batch"""

    # where policy_model(batch["chosen"]) are the logits
    policy_chosen_log_probas = compute_logprobs(
        logits=policy_model(batch["chosen"]),
        labels=batch["chosen"],
        selection_mask=batch["chosen_mask"]
    )
    policy_rejected_log_probas = compute_logprobs(
        logits=policy_model(batch["rejected"]),
        labels=batch["rejected"],
        selection_mask=batch["rejected_mask"]
    )
    with torch.no_grad():
        ref_chosen_log_probas = compute_logprobs(
            logits=reference_model(batch["chosen"]),
            labels=batch["chosen"],
            selection_mask=batch["chosen_mask"]
        )
        ref_rejected_log_probas = compute_logprobs(
            logits=reference_model(batch["rejected"]),
            labels=batch["rejected"],
            selection_mask=batch["rejected_mask"]
        )
```

Could you give this a try?
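As a side note (not part of the patch above), here is a minimal sketch of a complementary safeguard, assuming the same `policy_model` and `reference_model` objects as in the notebook: freezing the reference model's parameters guarantees that no gradient graph is built for them, independent of the `torch.no_grad()` context.

```python
import torch

# Sketch of an additional safeguard (assumption, not the notebook's code):
# freeze the reference model so autograd never tracks its parameters.
reference_model.eval()                       # fixed eval-mode behavior for the frozen model
for param in reference_model.parameters():
    param.requires_grad_(False)              # no gradients will be stored for these weights

# Only the policy model's parameters are registered with the optimizer,
# so optimizer.step() can only ever change the policy model.
optimizer = torch.optim.AdamW(policy_model.parameters(), lr=5e-6, weight_decay=0.01)
```

With the parameters frozen, the `torch.no_grad()` block is no longer the only thing preventing gradient tracking, but it is still useful for keeping activation memory down.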
-
Yes! Gradients are no longer computed for the reference model, and the memory requirements are drastically reduced. This is a good patch that could be merged into master. On another note (not trying to mix topics here), you mentioned somewhere that you intend to also implement RLHF in the future. Is that still the plan?
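For anyone who wants to quantify the savings, here is a rough sketch of a measurement on a CUDA device; the `beta` value is a placeholder, and it assumes `compute_dpo_loss_batch` returns the scalar loss for the batch (its return value isn't shown above).

```python
import torch

# Rough sketch for measuring peak GPU memory of one DPO forward/backward pass
# (assumes the policy_model / reference_model / batch objects from the notebook;
# beta=0.1 is a placeholder and the scalar-loss return value is an assumption).
torch.cuda.reset_peak_memory_stats()

loss = compute_dpo_loss_batch(batch, policy_model, reference_model, beta=0.1)
loss.backward()

peak_gb = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak GPU memory for one DPO step: {peak_gb:.2f} GB")
```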
-
I have noticed that the weights of the reference model are being updated during DPO training in ch07, section 04_preference-tuning-with-dpo. You can see that, for example, by looking at reference_model.out_head.weight.grad, which is not None.
My understanding was that the reference model does not get gradient updates and that only the policy model is being changed. If that were the case, there would be no need to compute gradients for the reference model, but it seems like they are computed nonetheless.
Could you please clarify why this is the case?
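For reference, the check I mean looks roughly like this, run with the original (pre-fix) code; the `beta` value is a placeholder, and I'm assuming `compute_dpo_loss_batch` returns the scalar loss for the batch.

```python
# Sketch of the check described above, run with the original (pre-fix) code
# (beta=0.1 is a placeholder; the scalar-loss return value is an assumption).
loss = compute_dpo_loss_batch(batch, policy_model, reference_model, beta=0.1)
loss.backward()

# With the original code this prints a tensor rather than None:
print(reference_model.out_head.weight.grad)

# Count how many reference-model parameter tensors received gradients:
num_with_grad = sum(p.grad is not None for p in reference_model.parameters())
print(f"{num_with_grad} reference model parameter tensors have gradients")
```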