Skip to content

Commit

Permalink
Εnable reward model offloading option (deepspeedai#930)
Browse files Browse the repository at this point in the history
* enable reward model offloading option

* fixed code formatting

* more formatting fixes

* Pre-commit formatting fix

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Signed-off-by: zhangsmallshark <[email protected]>
  • Loading branch information
4 people authored and zhangsmallshark committed Feb 12, 2025
1 parent cab3361 commit 18450b0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 13 deletions.
17 changes: 4 additions & 13 deletions applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,31 +268,22 @@ def _init_reward(self, critic_model_name_or_path):
# If critic is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
zero_stage = 0

ds_config = get_eval_ds_config(offload=self.args.offload,
ds_config = get_eval_ds_config(offload=self.args.offload_reward_model,
dtype=self.args.dtype,
stage=zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
ds_config[
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
) * self.args.gradient_accumulation_steps

ds_eval_config = get_eval_ds_config(offload=False,
dtype=self.args.dtype,
stage=zero_stage)

# We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
ds_eval_config[
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
ds_eval_config[
ds_config[
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
) * self.args.gradient_accumulation_steps

# Model
reward_model = create_critic_model(
model_name_or_path=critic_model_name_or_path,
tokenizer=self.tokenizer,
ds_config=ds_eval_config,
ds_config=ds_config,
num_padding_at_beginning=self.args.num_padding_at_beginning,
rlhf_training=True,
dropout=self.args.critic_dropout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ def parse_args():
'--offload_reference_model',
action='store_true',
help='Enable ZeRO Offload techniques for reference model')
parser.add_argument('--offload_reward_model',
action='store_true',
help='Enable ZeRO Offload techniques for reward model')
parser.add_argument(
'--actor_zero_stage',
type=int,
Expand Down

0 comments on commit 18450b0

Please sign in to comment.