Merge branch 'main' into vllm-0.7-npu
as12138 authored Feb 26, 2025
2 parents d36c1c7 + b4c13ce commit cf73a3a
Showing 63 changed files with 584 additions and 230 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/model.yml
@@ -52,3 +52,22 @@ jobs:
        run: |
          pip3 install hf_transfer
          torchrun --nproc_per_node=8 tests/checkpoint/test_fsdp_ckpt.py
      - name: Running transformers ulysses tests on 8 L20 GPUs + latest transformers
        run: |
          torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
      - name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.48.0
        run: |
          pip3 install transformers==4.48.0
          torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
      - name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.47.0
        run: |
          pip3 install transformers==4.47.0
          torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
      - name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.46.0
        run: |
          pip3 install transformers==4.46.0
          torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
      - name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.45.0
        run: |
          pip3 install transformers==4.45.0
          torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
5 changes: 5 additions & 0 deletions .github/workflows/vllm.yml
@@ -51,3 +51,8 @@ jobs:
          pip3 install --upgrade vllm
          cd tests/rollout
          torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_vllm_spmd.py
      - name: Run QWen 0.5B generation test
        run: |
          cd tests/generation
          bash ./run_gen_qwen05.sh 4 $HOME/data/gen/qwen_05_gen_test.parquet
          rm -rf $HOME/data/gen/qwen_05_gen_test.parquet
5 changes: 3 additions & 2 deletions README.md
@@ -43,7 +43,7 @@ verl is fast with:
- **vLLM** and **TGI** for rollout generation, **SGLang** support coming soon.
- huggingface models support
- Supervised fine-tuning
- Reinforcement learning from human feedback with [PPO](https://github.com/volcengine/verl/tree/main/examples/ppo_trainer), [GRPO](https://github.com/volcengine/verl/tree/main/examples/grpo_trainer), [ReMax](https://github.com/volcengine/verl/tree/main/examples/remax_trainer), Reinforce++, etc
- Reinforcement learning from human feedback with [PPO](https://github.com/volcengine/verl/tree/main/examples/ppo_trainer), [GRPO](https://github.com/volcengine/verl/tree/main/examples/grpo_trainer), [ReMax](https://github.com/volcengine/verl/tree/main/examples/remax_trainer), Reinforce++, [RLOO](https://github.com/volcengine/verl/tree/main/examples/rloo_trainer/run_qwen2-7b.sh), etc
- Support model-based reward and function-based reward (verifiable reward)
- flash-attention, [sequence packing](examples/ppo_trainer/run_qwen2-7b_seq_balance.sh), [long context](examples/ppo_trainer/run_deepseek7b_llm_sp2.sh) support via DeepSpeed Ulysses, [LoRA](examples/sft/gsm8k/run_qwen_05_peft.sh), [Liger-kernel](examples/sft/gsm8k/run_qwen_05_sp2_liger.sh)
- scales up to 70B models and hundreds of GPUs
@@ -52,7 +52,7 @@ verl is fast with:
## Upcoming Features
- Reward model training
- DPO training
- DeepSeek integration with Megatron backend
- DeepSeek integration with Megatron v0.11
- SGLang integration
- vision language model RL

@@ -88,6 +88,7 @@ verl is fast with:
- [Deployment using Separate GPU Resources](https://github.com/volcengine/verl/tree/main/examples/split_placement)

**Blogs from the community**
- [Best practices for distributed GRPO reinforcement learning training with verl](https://www.volcengine.com/docs/6459/1463942)
- [A brief walkthrough of the HybridFlow (veRL) paper](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/readme.md)
- [Up to 20x higher throughput! The Doubao LLM team releases a brand-new RLHF framework, now open source!](https://team.doubao.com/en/blog/%E6%9C%80%E9%AB%98%E6%8F%90%E5%8D%8720%E5%80%8D%E5%90%9E%E5%90%90%E9%87%8F-%E8%B1%86%E5%8C%85%E5%A4%A7%E6%A8%A1%E5%9E%8B%E5%9B%A2%E9%98%9F%E5%8F%91%E5%B8%83%E5%85%A8%E6%96%B0-rlhf-%E6%A1%86%E6%9E%B6-%E7%8E%B0%E5%B7%B2%E5%BC%80%E6%BA%90)

7 changes: 2 additions & 5 deletions docs/examples/config.rst
@@ -19,7 +19,6 @@ Data
max_prompt_length: 512
max_response_length: 512
train_batch_size: 1024
val_batch_size: 1312
return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
return_raw_chat: False
@@ -39,8 +38,6 @@ Data
algorithms (e.g. PPO) generates up to this length
- ``data.train_batch_size``: Batch size sampled for one training
iteration of different RL algorithms.
- ``data.val_batch_size``: Batch size sampled for one validation
iteration.
- ``data.return_raw_input_ids``: Whether to return the original
input_ids without adding chat template. This is mainly used to
accommodate situations where the reward model's chat template differs
@@ -130,7 +127,7 @@ Actor/Rollout/Reference Policy
# for hf rollout
do_sample: True
# number of responses (i.e. num sample times)
n: 1 # > 1 for grpo
n: 1 # > 1 for grpo, rloo
**Common config for actor, rollout and reference model**

@@ -328,7 +325,7 @@ Algorithm
- ``gamma``: discount factor
- ``lam``: Trade-off between bias and variance in the GAE estimator
- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``.
- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``rloo``
- ``kl_penalty``: Support ``kl``, ``abs``, ``mse`` and ``full``. How to
calculate the kl divergence between actor and reference policy. For
specific options, refer to `core_algos.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py#L192>`_ .
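For instance, a minimal sketch of selecting one of these advantage estimators (here ``rloo``) from the command line, mirroring ``examples/rloo_trainer/run_qwen2-7b.sh`` added in this commit, with all other overrides elided:

.. code:: bash

    python3 -m verl.trainer.main_ppo \
        algorithm.adv_estimator=rloo \
        actor_rollout_ref.rollout.n=5 \
        ...   # data, model and trainer overrides as in the example script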
1 change: 0 additions & 1 deletion docs/examples/gsm8k_example.rst
@@ -130,7 +130,6 @@ The script of run_deepseek7b_llm.sh
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
14 changes: 14 additions & 0 deletions docs/faq/faq.rst
@@ -18,6 +18,8 @@ How to run multi-node post-training with Ray?

You can start a ray cluster and submit a ray job, following the official guide from Ray: https://docs.ray.io/en/latest/ray-core/starting-ray.html

Then, in the configuration, set ``trainer.nnodes`` to the number of machines for your job, for example:
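A minimal sketch (the head-node address and the node/GPU counts below are placeholders; the remaining overrides follow the quickstart examples):

.. code:: bash

    # On the head node
    ray start --head --port=6379
    # On every worker node
    ray start --address=<head-node-ip>:6379
    # Submit the job with a node count that matches the cluster
    python3 -m verl.trainer.main_ppo \
        trainer.nnodes=2 \
        trainer.n_gpus_per_node=8 \
        ...   # data, model and rollout overrides as in the quickstart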

How to use verl on a Slurm-managed cluster?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -41,3 +43,15 @@ manager available on your cluster or use other container runtimes (e.g. through

Please note that Slurm cluster setup may vary. If you encounter any issues, please refer to Ray's
`Slurm user guide <https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html>`_ for common caveats.

Illegal memory access
---------------------------------

If you encounter an error message like ``CUDA error: an illegal memory access was encountered`` during rollout, it is most likely due to a known issue in vLLM.
Set the following environment variable. If Ray is started manually, the variable must be exported before the ``ray start`` command.

.. code:: bash

    export VLLM_ATTENTION_BACKEND=XFORMERS

If in doubt, print this env var in each rank to make sure it is properly set.
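For example, when launching the cluster manually, a sketch of the required ordering (head node shown; worker nodes are analogous):

.. code:: bash

    export VLLM_ATTENTION_BACKEND=XFORMERS   # must be exported first
    ray start --head
    # on workers: export the same variable, then `ray start --address=<head-node-ip>:6379`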
1 change: 0 additions & 1 deletion docs/start/quickstart.rst
@@ -85,7 +85,6 @@ Set the ``data.train_files`` ,\ ``data.val_files``, ``actor_rollout_ref.model.pa
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=256 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=256 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
1 change: 0 additions & 1 deletion examples/grpo_trainer/run_deepseek7b_llm.sh
@@ -5,7 +5,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh
@@ -5,7 +5,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion examples/grpo_trainer/run_qwen2-7b.sh
@@ -7,7 +7,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
1 change: 0 additions & 1 deletion examples/grpo_trainer/run_qwen2-7b_seq_balance.sh
@@ -7,7 +7,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_deepseek7b_llm.sh
@@ -4,7 +4,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_deepseek7b_llm_sp2.sh
@@ -4,7 +4,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh
@@ -7,7 +7,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=512 \
data.val_batch_size=128 \
data.max_prompt_length=128 \
data.max_response_length=128 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh
@@ -12,7 +12,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.val_batch_size=6312 \
data.max_prompt_length=1024 \
data.max_response_length=512 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-coder-6.7b-instruct \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_deepseek_megatron.sh
@@ -13,7 +13,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=$HOME/models/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_gemma.sh
@@ -4,7 +4,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=512 \
data.val_batch_size=1312 \
data.max_prompt_length=1024 \
data.max_response_length=512 \
actor_rollout_ref.model.path=google/gemma-2-2b-it \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_qwen2-7b.sh
@@ -12,7 +12,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.val_batch_size=6312 \
data.max_prompt_length=1024 \
data.max_response_length=512 \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh
@@ -14,7 +14,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.val_batch_size=6312 \
data.max_prompt_length=1024 \
data.max_response_length=512 \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_qwen2-7b_rm.sh
@@ -26,7 +26,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.val_batch_size=6312 \
data.max_prompt_length=1024 \
data.max_response_length=512 \
data.return_raw_chat=True \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh
@@ -12,7 +12,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=4096 \
data.val_batch_size=1312 \
data.max_prompt_length=4096 \
data.max_response_length=4096 \
data.return_raw_chat=True \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_qwen2-7b_seq_balance.sh
@@ -12,7 +12,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=4096 \
data.val_batch_size=1312 \
data.max_prompt_length=4096 \
data.max_response_length=4096 \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_qwen2.5-32b.sh
@@ -12,7 +12,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.val_batch_size=6304 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \
14 changes: 7 additions & 7 deletions examples/ppo_trainer/verl_getting_started.ipynb
@@ -314,16 +314,16 @@
"source": [
"import torch\n",
"try:\n",
" assert torch.cuda.is_available() is True\n",
" torch.ones(1, dtype=torch.bfloat16).cuda()\n",
" assert torch.cuda.is_available() is True\n",
" torch.ones(1, dtype=torch.bfloat16).cuda()\n",
"except AssertionError:\n",
" print(\"Please switch to an env with GPUs supporting bfloat16 (L4 RTX 5000, A5000, A100, H100, A10, etc)\")\n",
" print(\"Please switch to an env with GPUs supporting bfloat16 (L4 RTX 5000, A5000, A100, H100, A10, etc)\")\n",
"\n",
"try:\n",
" import verl\n",
" import verl\n",
"except Exception as e:\n",
" print(\"Please install verl via pip and restart the kernel\")\n",
" raise e\n",
" print(\"Please install verl via pip and restart the kernel\")\n",
" raise e\n",
"\n",
"import flash_attn"
]
@@ -561,6 +561,7 @@
"source": [
"import inspect\n",
"from verl.utils.reward_score.gsm8k import compute_score as gsm8k_reward\n",
"\n",
"print(inspect.getsource(gsm8k_reward))"
]
},
@@ -1103,7 +1104,6 @@
" data.train_files=$HOME/data/gsm8k/train.parquet \\\n",
" data.val_files=$HOME/data/gsm8k/test.parquet \\\n",
" data.train_batch_size=256 \\\n",
" data.val_batch_size=1312 \\\n",
" data.max_prompt_length=512 \\\n",
" data.max_response_length=256 \\\n",
" actor_rollout_ref.model.path=$HOME/models/Qwen2.5-0.5B-Instruct \\\n",
1 change: 0 additions & 1 deletion examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh
@@ -10,7 +10,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/train.parquet \
data.train_batch_size=512 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
1 change: 0 additions & 1 deletion examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh
@@ -10,7 +10,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/train.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \
40 changes: 40 additions & 0 deletions examples/rloo_trainer/run_qwen2-7b.sh
@@ -0,0 +1,40 @@
set -x

export VLLM_ATTENTION_BACKEND=XFORMERS

python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=rloo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_rloo_example_gsm8k' \
trainer.experiment_name='qwen2_7b_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
1 change: 0 additions & 1 deletion examples/slurm/ray_on_slurm.slurm
@@ -75,7 +75,6 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "$head_node" \
data.train_files=$train_files \
data.val_files=$val_files \
data.train_batch_size=256 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=256 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
2 changes: 1 addition & 1 deletion examples/split_placement/config/ppo_trainer_split.yaml
@@ -6,7 +6,7 @@ data:
max_prompt_length: 512
max_response_length: 512
train_batch_size: 1024
val_batch_size: 1312
val_batch_size: null # DEPRECATED: validation datasets are sent to the inference engines as a whole batch, and the engines schedule memory themselves
return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
return_raw_chat: False
shuffle: True
1 change: 0 additions & 1 deletion examples/split_placement/run_deepseek7b_llm.sh
@@ -4,7 +4,6 @@ python3 main_ppo_split.py \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \