Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Liger-kernel] Add an option to use _apply_liger_kernel_to_instance() to load model #133

Merged
merged 9 commits into from
Jan 30, 2025
52 changes: 52 additions & 0 deletions tests/e2e/run_qwen_gsm8k_model_rm_liger_kernel.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
set -x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add this test to CI and make sure liger-kernel is toggled?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, added this test to .github/workflows/e2e_gsm8k.yml

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess liger-kernel has to be added to setup.py and project.toml


export VLLM_ATTENTION_BACKEND=XFORMERS

python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
+actor_rollout_ref.model.use_liger=True \
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The flag for use_liger is here.

actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=32 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.optim.lr_warmup_steps_ratio=0.05 \
critic.model.path=Qwen/Qwen2.5-0.5B \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=32 \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.grad_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
reward_model.enable=True \
reward_model.model.path=Qwen/Qwen2.5-0.5B\
reward_model.model.use_remove_padding=True \
reward_model.model.fsdp_config.param_offload=True \
reward_model.micro_batch_size=16 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
+trainer.val_before_train=False \
trainer.project_name='verl_example' \
trainer.experiment_name='Qwen2.5-0.5B-ci_hybrid_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_training_steps=1 $@
14 changes: 14 additions & 0 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def _build_model_optimizer(self,
use_remove_padding=False,
enable_gradient_checkpointing=False,
trust_remote_code=False,
use_liger=False,
role='actor'):
from verl.utils.model import print_model_size, update_model_config
from verl.utils.torch_dtypes import PrecisionType
Expand Down Expand Up @@ -193,6 +194,17 @@ def _build_model_optimizer(self,
config=actor_model_config,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
# Apply Liger kernel to the model if use_liger is set to True
if use_liger:
try:
# Import Liger kernel module and use it to load the model
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance
_apply_liger_kernel_to_instance(model=actor_module)
except ImportError:
# Fallback to use AutoModelForCausalLM and print warning message
logger.warning("Liger kernel was requested but not installed - falling back to AutoModelForCausalLM")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I would prefer for the job to fail outright if liger is requested but not installed. Just printing a warning is too easy to miss in all the job outputs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like your proposal. I think we can remove the try ... catch ... logic and add liger_kernel a required dependency of verl. There are still a few extra PRs in Liger kernel targeting for the full integration. Once those are done, I think it makes sense to add liger-kernel to the requirements.txt cc @vermouth1992 @eric-haibin-lin

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's a good idea to make liger-kernel a required dependency and remove try except.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove try except and run the CI again?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, done.

logger.warning("To enable Liger kernel, install it with: pip install liger-kernel")

# some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
actor_module.to(torch_dtype)

Expand Down Expand Up @@ -333,6 +345,7 @@ def init_model(self):
use_remove_padding=use_remove_padding,
enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False),
trust_remote_code=self.config.model.get('trust_remote_code', False),
use_liger=self.config.model.get('use_liger', False),
role='actor')

# get the original unwrapped module
Expand Down Expand Up @@ -365,6 +378,7 @@ def init_model(self):
use_remove_padding=use_remove_padding,
trust_remote_code=self.config.model.get(
'trust_remote_code', False),
use_liger=self.config.model.get('use_liger', False),
role='ref')[0]
OmegaConf.set_struct(self.config.ref, True)
with open_dict(self.config.ref):
Expand Down