From e5cbb21498ee7c42e5d744ed2dc105f2bb39719a Mon Sep 17 00:00:00 2001
From: YeZhengMao
Date: Mon, 29 Jul 2024 17:14:49 +0800
Subject: [PATCH] [benchmark] add fsdp, tp, pp benchmark (#243)

---
 benchmarks/bench_mlora.py       | 161 ----------------------
 benchmarks/bench_mlora_pp.py    | 169 ++++++++++++++++++++++++++++++
 benchmarks/bench_peft.py        | 167 -----------------------
 benchmarks/bench_peft_fsdp.py   | 179 ++++++++++++++++++++++++++++++++
 benchmarks/bench_peft_tp.py     | 161 ++++++++++++++++++++++++++++
 mlora/config/task.py            |   3 +
 mlora/executor/task/__init__.py |  31 +++++-
 mlora/utils/cmd.py              |   4 +-
 pyproject.mlora.toml            |   2 +-
 9 files changed, 545 insertions(+), 332 deletions(-)
 delete mode 100644 benchmarks/bench_mlora.py
 create mode 100644 benchmarks/bench_mlora_pp.py
 delete mode 100644 benchmarks/bench_peft.py
 create mode 100644 benchmarks/bench_peft_fsdp.py
 create mode 100644 benchmarks/bench_peft_tp.py

diff --git a/benchmarks/bench_mlora.py b/benchmarks/bench_mlora.py
deleted file mode 100644
index 8e2f4d4f..00000000
--- a/benchmarks/bench_mlora.py
+++ /dev/null
@@ -1,161 +0,0 @@
-from mlora.utils import setup_seed
-from mlora.config import LoRAConfig
-from mlora.model.args import MLoRABatchData, MLoRADataConfig
-from mlora.profiler.profiler import setup_trace_mode, set_backward_tracepoint, grad_fn_nvtx_wrapper_by_tracepoint, nvtx_range
-
-import mlora
-import torch
-import random
-import argparse
-
-from typing import List
-
-# Command Line Arguments
-parser = argparse.ArgumentParser(description='PEFT benchmarks')
-parser.add_argument('--base_model', type=str, required=True,
-                    help='Path to or name of base model')
-parser.add_argument('--device', type=str, default='cuda:0',
-                    help='Specify which GPU to be used, default is cuda:0')
-# load quant
-parser.add_argument('--load_8bit', action="store_true",
-                    help='Load model in 8bit mode')
-parser.add_argument('--load_4bit', action="store_true",
-                    help='Load model in 4bit mode')
-# lora test number
-parser.add_argument('--lora_cnt', type=int, default=4,
-                    help='The number of lora')
-# test configure
-parser.add_argument('--warmup', type=int, default=100,
-                    help="The step of warm up")
-parser.add_argument('--repete', type=int, default=100,
-                    help="Total test iteration")
-parser.add_argument('--seq_len', type=int, default=128,
-                    help="The length of the sequence")
-parser.add_argument('--batch_size', type=int, default=8,
-                    help="The batch size of each lora input")
-
-
-g_default_rank = 16
-g_default_alpha = 16
-g_default_dropout = 0.05
-g_default_target_modules = {"q_proj": True,
-                            "k_proj": True,
-                            "v_proj": True,
-                            "o_proj": True,
-                            "w1_proj": False,
-                            "w2_proj": False,
-                            "w3_proj": False}
-g_default_loss_fn = torch.nn.CrossEntropyLoss()
-
-args = parser.parse_args()
-assert not (args.load_4bit and args.load_8bit)
-
-
-def setup_lora_adapter_config() -> List[LoRAConfig]:
-    lora_config: List[LoRAConfig] = []
-
-    for idx in range(0, args.lora_cnt):
-        lora_config.append(LoRAConfig({
-            "name": f"lora_{idx}",
-            "r": g_default_rank,
-            "alpha": g_default_alpha,
-            "dropout": g_default_dropout,
-            "target_modules": g_default_target_modules,
-            "batch_size": args.batch_size,
-            "micro_batch_size": args.batch_size,
-            # unused
-            "test_batch_size": 0,
-            "num_epochs": 0,
-            "data": "",
-            "test_data": "",
-            "prompt": "",
-            "group_by_length": "",
-            "expand_side": "",
-            "optim": "sgd",
-            "momentum": 0.0,
-            "lr": 0.0,
-        }))
-
-    return lora_config
-
-
-def setup_input() -> MLoRABatchData:
-    batch_tokens = []
-    additional_masks = []
-    lora_batch_data_config: List[MLoRADataConfig] = []
-
-    start_idx = 0
-    end_idx = 0
-
-    for lora_idx in range(0, args.lora_cnt):
-        adapter_name = f"lora_{lora_idx}"
-
-        for _ in range(0, args.batch_size):
-            tokens = [random.randint(1, 10000) for _ in range(args.seq_len)]
-            batch_tokens.append(tokens)
-            additional_masks.append([False] * args.seq_len)
-            end_idx += 1
-
-        lora_batch_data_config.append(MLoRADataConfig(
-            adapter_name_=adapter_name,
-            batch_start_idx_=start_idx,
-            batch_end_idx_=end_idx,
-        ))
-
-        start_idx = end_idx
-
-    return MLoRABatchData(batch_tokens_=batch_tokens,
-                          batch_mask_=additional_masks,
-                          lora_batch_data_config_=lora_batch_data_config,
-                          inference_model_=False)
-
-
-def calc_loss(train_data: MLoRABatchData, model_output: torch.Tensor) -> torch.Tensor:
-    labels = torch.tensor(train_data.batch_tokens_, dtype=torch.long)
-    total_loss = None
-
-    for lora_config in train_data.lora_batch_data_config_:
-        start_idx = lora_config.batch_start_idx_
-        end_idx = lora_config.batch_end_idx_
-        vocab_size = model_output.shape[-1]
-        loss_input = model_output[start_idx:end_idx][...,
-                                                     :-1, :].contiguous().view(-1, vocab_size)
-        loss_target = labels[start_idx:end_idx][...,
-                                                1:].contiguous().view(-1).to(loss_input.device)
-        loss = g_default_loss_fn(loss_input, loss_target)
-        if total_loss is None:
-            total_loss = loss
-        else:
-            total_loss += loss
-
-    return total_loss
-
-
-if __name__ == "__main__":
-    input_data = setup_input()
-
-    setup_seed(42)
-
-    _, model = mlora.load_base_model(args.base_model,
-                                     "llama",
-                                     args.device,
-                                     args.load_4bit,
-                                     args.load_8bit,
-                                     None)
-
-    mlora.init_lora_model(model, setup_lora_adapter_config())
-
-    # to wramup
-    for test_idx in range(0, args.warmup):
-        output = model.forward(input_data)
-
-    setup_trace_mode()
-
-    for _ in range(0, args.repete):
-        output = model.forward(input_data)
-        with nvtx_range("f_calc_loss"):
-            total_loss = calc_loss(input_data, output)
-        set_backward_tracepoint(total_loss.grad_fn, "b_loss")
-        grad_fn_nvtx_wrapper_by_tracepoint(total_loss.grad_fn)
-
-        total_loss.backward()
diff --git a/benchmarks/bench_mlora_pp.py b/benchmarks/bench_mlora_pp.py
new file mode 100644
index 00000000..0151c20d
--- /dev/null
+++ b/benchmarks/bench_mlora_pp.py
@@ -0,0 +1,169 @@
+import logging
+import random
+import time
+from typing import List, Tuple, override
+
+import mlora.config
+import mlora.executor
+import mlora.executor.dispatcher
+import mlora.model
+import mlora.utils
+import torch
+from mlora.config.adapter import LoRAConfig
+from mlora.config.dispatcher import DispatcherConfig
+from mlora.config.task import TrainTaskConfig
+from mlora.executor.task import TrainTask, register_task_class
+from mlora.model.args import MLoRADataConfig, Tokens
+
+g_start_time: float | None = None
+g_total_token: int = 0
+
+
+class BenchmarkArgs:
+    batch_size_: int = 8
+    seq_len_: int = 512
+    concurrency_num_: int = 4
+    test_epochs_: int = 20
+
+
+class BenchmarkConfig(mlora.config.MLoRAConfig):
+    dispatcher_: DispatcherConfig
+
+    def __init__(self):
+        # just for init
+        self.dispatcher_ = DispatcherConfig(
+            {"name": "pipe", "concurrency_num": BenchmarkArgs.concurrency_num_}
+        )
+
+
+class BenchmarkTask(TrainTask):
+    def __init__(self, config: mlora.config.TaskConfig, llm_name: str) -> None:
+        super().__init__(config, llm_name)
+
+    @override
+    def data(self, start_idx: int) -> Tuple[List[Tokens], List[MLoRADataConfig]]:
+
+        ret_tokens = [
+            [random.randint(1, 10000) for _ in range(BenchmarkArgs.seq_len_)]
+        ] * BenchmarkArgs.batch_size_
+
+        end_idx = start_idx + len(ret_tokens)
+
+        def loss_fn(
+            input: torch.Tensor, target: torch.Tensor, _: torch.Tensor
+        ) -> torch.Tensor:
+            vocab_size = input.shape[-1]
+            loss_input = (
+                input[start_idx:end_idx, :-1, :].contiguous().view(-1, vocab_size)
+            )
+            loss_target = (
+                target[start_idx:end_idx, 1:]
+                .contiguous()
+                .view(-1)
+                .to(loss_input.device)
+            )
+            loss: torch.Tensor = self.context_.loss_fn_(loss_input, loss_target)
+
+            return loss
+
+        data_config = MLoRADataConfig(
+            self.context_.name_,
+            self.context_.type_,
+            start_idx,
+            end_idx,
+            self._expand_batch_tokens,
+            loss_fn,
+            self.task_name(),
+        )
+
+        return ret_tokens, [data_config]
+
+    @override
+    def step(self):
+        self.now_step_ += 1
+        self.now_epoch_ += 1
+        self.context_.step()
+
+        global g_start_time
+        global g_total_token
+        if g_start_time is not None:
+            g_total_token = g_total_token + (
+                BenchmarkArgs.batch_size_ * BenchmarkArgs.seq_len_
+            )
+            logging.info(
+                f"average {g_total_token / (time.time() - g_start_time) : .2f} tokens/s"
+            )
+        else:
+            g_start_time = time.time()
+
+        logging.info(f"task {self.context_.name_} step")
+
+
+def generate_task_config(task_idx: int) -> TrainTaskConfig:
+    adapters = {
+        f"test_{task_idx}": LoRAConfig(
+            {
+                "type": "lora",
+                "name": f"test_{task_idx}",
+                "path": f"adapters/test_{task_idx}",
+                "r": 16,
+                "alpha": 16,
+                "dropout": 0.05,
+                "target_modules": {
+                    "q_proj": True,
+                    "k_proj": True,
+                    "v_proj": True,
+                    "o_proj": True,
+                },
+                "optimizer": "adamw",
+                "lr": 1e-3,
+            }
+        )
+    }
+    datasets = {f"test_{task_idx}": None}
+
+    return TrainTaskConfig(
+        {
+            "batch_size": BenchmarkArgs.batch_size_,
+            "mini_batch_size": BenchmarkArgs.batch_size_,
+            "num_epochs": BenchmarkArgs.test_epochs_,
+            "cutoff_len": BenchmarkArgs.seq_len_,
+            "save_step": 10000000,
+            "name": f"test_{task_idx}",
+            "type": "benchmark",
+            "adapter": f"test_{task_idx}",
+            "dataset": f"test_{task_idx}",
+        },
+        adapters,
+        datasets,
+    )
+
+
+if __name__ == "__main__":
+    args = mlora.utils.get_cmd_args()
+
+    mlora.utils.setup_seed(args.seed)
+    mlora.utils.setup_logging(args.log_level, args.log_file)
+    mlora.utils.setup_cuda_check()
+
+    register_task_class("benchmark", BenchmarkTask)
+
+    # enable the trace mode for profiling performance
+    if args.trace:
+        mlora.utils.setup_trace_mode()
+
+    tokenizer, model = mlora.model.load_model(args)
+
+    config = BenchmarkConfig()
+
+    # init all tasks from the config file
+    executor = mlora.executor.PipeExecutor(
+        model, tokenizer, config, args.device, args.rank, args.balance, args.recompute
+    )
+
+    # only the head node can add tasks
+    if args.rank == 0:
+        for idx in range(0, BenchmarkArgs.concurrency_num_):
+            executor.add_task(generate_task_config(idx))
+
+    executor.execute()
diff --git a/benchmarks/bench_peft.py b/benchmarks/bench_peft.py
deleted file mode 100644
index c257a76c..00000000
--- a/benchmarks/bench_peft.py
+++ /dev/null
@@ -1,167 +0,0 @@
-from mlora.utils import setup_seed
-from mlora.profiler.profiler import setup_trace_mode, grad_fn_nvtx_wrapper_by_tracepoint, set_backward_tracepoint
-
-import torch
-import random
-import argparse
-import logging
-
-from transformers import LlamaForCausalLM
-from peft import LoraConfig, TaskType, PeftModelForCausalLM, prepare_model_for_kbit_training
-
-# Command Line Arguments
-parser = argparse.ArgumentParser(description='PEFT benchmarks')
-parser.add_argument('--base_model', type=str, required=True,
-                    help='Path to or name of base model')
-parser.add_argument('--device', type=str, default='cuda:0',
-                    help='Specify which GPU to be used, default is cuda:0')
-# load quant
-parser.add_argument('--load_8bit', action="store_true",
-                    help='Load model in 8bit mode')
-parser.add_argument('--load_4bit', action="store_true",
-                    help='Load model in 4bit mode')
-# lora test number
-parser.add_argument('--lora_cnt', type=int, default=4,
-                    help='The number of lora')
-# test configure
-parser.add_argument('--warmup', type=int, default=100,
-                    help="The step of warm up")
-parser.add_argument('--repete', type=int, default=100,
-                    help="Total test iteration")
-parser.add_argument('--seq_len', type=int, default=128,
-                    help="The length of the sequence")
-parser.add_argument('--batch_size', type=int, default=8,
-                    help="The batch size of each lora input")
-parser.add_argument('--peft_mode', type=str, default="seq",
-                    help="How to use peft to train multi lora, include: seq, switch")
-
-g_default_rank = 16
-g_default_alpha = 16
-g_default_dropout = 0.05
-g_default_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
-g_default_loss_fn = torch.nn.CrossEntropyLoss()
-
-args = parser.parse_args()
-assert not (args.load_4bit and args.load_8bit)
-
-
-def setup_lora_adapter(llm_model: LlamaForCausalLM) -> PeftModelForCausalLM:
-    peft_llm_model = llm_model
-
-    for idx in range(0, args.lora_cnt):
-        adapter_name = f"lora_{idx}"
-        lora_r = g_default_rank
-        lora_alpha = g_default_alpha
-        lora_dropout = g_default_dropout
-        lora_target = g_default_target_modules
-        peft_lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
-                                      r=lora_r,
-                                      lora_alpha=lora_alpha,
-                                      lora_dropout=lora_dropout,
-                                      target_modules=lora_target,
-                                      bias="none",
-                                      inference_mode=False)
-        peft_llm_model = PeftModelForCausalLM(
-            peft_llm_model, peft_lora_config, adapter_name)
-
-    return peft_llm_model
-
-
-def setup_llm_model() -> LlamaForCausalLM:
-    load_bits = None
-    load_bits = 8 if args.load_8bit else load_bits
-    load_bits = 4 if args.load_4bit else load_bits
-
-    qlora_4bit_fp16 = True
-    qlora_4bit_bf16 = False
-    qlora_4bit_double_quant = True
-    qlora_4_bit_quant_type = "nf4"
-
-    additional_load_args = {
-        "device_map": args.device,
-        "torch_dtype": torch.float32
-    }
-
-    if load_bits is not None:
-        logging.info('Loading model with quantization, bits = %i' % load_bits)
-        from transformers import BitsAndBytesConfig
-        qlora_4bit_compute_dtype = torch.float32
-        # if set the compute type, then change it, otherwise hold the default
-        qlora_4bit_compute_dtype = torch.float16 if qlora_4bit_fp16 else qlora_4bit_compute_dtype
-        qlora_4bit_compute_dtype = torch.bfloat16 if qlora_4bit_bf16 else qlora_4bit_compute_dtype
-
-        torch_dtype = torch.float32
-        torch_dtype = torch.bfloat16 if qlora_4bit_bf16 else torch_dtype
-        additional_load_args["torch_dtype"] = torch_dtype
-        additional_load_args["quantization_config"] = BitsAndBytesConfig(
-            load_in_4bit=True if load_bits == 4 else False,
-            load_in_8bit=True if load_bits == 8 else False,
-            llm_int8_enable_fp32_cpu_offload=True,
-            # only for qlora 4bit
-            bnb_4bit_compute_dtype=qlora_4bit_compute_dtype,
-            bnb_4bit_use_double_quant=qlora_4bit_double_quant,
-            bnb_4bit_quant_type=qlora_4_bit_quant_type,
-        )
-
-    llm_model = LlamaForCausalLM.from_pretrained(
-        args.base_model, **additional_load_args)
-
-    llm_model = prepare_model_for_kbit_training(llm_model)
-    llm_model.training = True
-    llm_model.gradient_checkpointing_enable()
-
-    return llm_model
-
-
-def setup_labels() -> torch.Tensor:
-    batch_input_ids = []
-    for _ in range(0, args.batch_size):
-        batch_input_ids.append([random.randint(1, 10000)
-                                for _ in range(args.seq_len)])
-    return torch.tensor(batch_input_ids, dtype=torch.long,
-                        device=args.device)
-
-
-if __name__ == "__main__":
-    lables = setup_labels()
-
-    setup_seed(42)
-    model: LlamaForCausalLM = setup_llm_model()
-    vocab_size = model.vocab_size
-    model: PeftModelForCausalLM = setup_lora_adapter(model)
-    model.train()
-
-    # to wramup
-    for test_idx in range(0, args.warmup):
-        loss = model.forward(input_ids=lables, labels=lables)[0]
-
-    setup_trace_mode()
-
-    def lora_seq():
-        for lora_idx in range(0, args.lora_cnt):
-            now_lora = f"lora_{lora_idx}"
-            model.set_adapter(now_lora)
-            for _ in range(0, args.repete):
-                loss = model.forward(input_ids=lables, labels=lables)[0]
-                set_backward_tracepoint(loss.grad_fn, "b_loss")
-                grad_fn_nvtx_wrapper_by_tracepoint(loss.grad_fn)
-                loss.backward()
-
-    def lora_switch():
-        for _ in range(0, args.repete):
-            for lora_idx in range(0, args.lora_cnt):
-                now_lora = f"lora_{lora_idx}"
-                model.set_adapter(now_lora)
-                loss = model.forward(input_ids=lables, labels=lables)[0]
-                set_backward_tracepoint(loss.grad_fn, "b_loss")
-                grad_fn_nvtx_wrapper_by_tracepoint(loss.grad_fn)
-                loss.backward()
-
-    mode_function = {
-        "seq": lora_seq,
-        "switch": lora_switch,
-    }
-
-    peft_mode = args.peft_mode
-
-    assert peft_mode in mode_function, NotImplementedError
-    mode_function[peft_mode]()
diff --git a/benchmarks/bench_peft_fsdp.py b/benchmarks/bench_peft_fsdp.py
new file mode 100644
index 00000000..4790d3aa
--- /dev/null
+++ b/benchmarks/bench_peft_fsdp.py
@@ -0,0 +1,179 @@
+import argparse
+import os
+import random
+import time
+from dataclasses import dataclass
+
+import llama_recipes
+import llama_recipes.configs
+import llama_recipes.utils.fsdp_utils
+import llama_recipes.utils.train_utils
+import torch
+import torch.distributed.fsdp
+import torch.optim as optim
+from llama_recipes.policies import apply_fsdp_checkpointing
+from peft import LoraConfig, PeftModelForCausalLM, TaskType, get_peft_model
+from transformers import LlamaForCausalLM
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+
+@dataclass
+class BenchmarkArgs:
+    batch_size: int = 8
+    seq_len: int = 512
+    test_steps: int = 100
+
+
+@dataclass
+class PeftArgs:
+    rank = 16
+    alpha = 16
+    dropout = 0.05
+    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
+
+
+def dummy_train_labels(benchmark_args: BenchmarkArgs) -> torch.Tensor:
+    batch_input_ids = []
+    for _ in range(0, benchmark_args.batch_size):
+        batch_input_ids.append(
+            [random.randint(1, 10000) for _ in range(benchmark_args.seq_len)]
+        )
+    return torch.tensor(batch_input_ids, dtype=torch.long)
+
+
+def create_optimizer(
+    model: torch.distributed.fsdp.FullyShardedDataParallel,
+    train_config: llama_recipes.configs.train_config,
+):
+    optimizer = optim.AdamW(
+        model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay
+    )
+    return optimizer
+
+
+def setup_lora_adapter(
+    model: LlamaForCausalLM, peft_args: PeftArgs
+) -> PeftModelForCausalLM:
+    peft_lora_config = LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        r=peft_args.rank,
+        lora_alpha=peft_args.alpha,
+        lora_dropout=peft_args.dropout,
+        target_modules=peft_args.target_modules,
+        bias="none",
+        inference_mode=False,
+    )
+    model = get_peft_model(model, peft_lora_config)
+
+    return model
+
+
+def create_peft_model(
+    rank: int,
+    fsdp_config: llama_recipes.configs.fsdp_config,
+    train_config: llama_recipes.configs.train_config,
+) -> torch.distributed.fsdp.FullyShardedDataParallel:
+    # init the backbone and lora adapter
+    model = LlamaForCausalLM.from_pretrained(
+        train_config.model_name, use_cache=False, torch_dtype=torch.float32
+    )
+
+    peft_args = PeftArgs()
+    model = setup_lora_adapter(model, peft_args)
+    model.print_trainable_parameters()
+
+    # setup FSDP
+    device_id = torch.cuda.current_device()
+
+    auto_wrapping_policy = llama_recipes.utils.fsdp_utils.fsdp_auto_wrap_policy(
+        model, LlamaDecoderLayer
+    )
+    mixed_precision_policy, wrapping_policy = (
+        llama_recipes.utils.train_utils.get_policies(fsdp_config, rank)
+    )
+
+    model = torch.distributed.fsdp.FullyShardedDataParallel(
+        model,
+        auto_wrap_policy=(
+            auto_wrapping_policy if train_config.use_peft else wrapping_policy
+        ),
+        cpu_offload=None,
+        mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
+        sharding_strategy=fsdp_config.sharding_strategy,
+        device_id=device_id,
+        limit_all_gathers=True,
+        sync_module_states=train_config.low_cpu_fsdp,
+        param_init_fn=lambda module: (
+            module.to_empty(device=torch.device("cuda"), recurse=False)
+            if train_config.low_cpu_fsdp and rank != 0
+            else None
+        ),
+    )
+
+    if fsdp_config.fsdp_activation_checkpointing:
+        apply_fsdp_checkpointing(model)
+
+    return model
+
+
+def init_args():
+    parser = argparse.ArgumentParser(description="PEFT FSDP benchmarks")
+    parser.add_argument(
+        "--base_model", type=str, required=True, help="Path to or name of base model"
+    )
+    return parser.parse_args()
+
+
+def train(
+    model: torch.distributed.fsdp.FullyShardedDataParallel,
+    optimizer: optim.AdamW,
+    labels: torch.Tensor,
+    rank: int,
+    local_rank: int,
+    benchmark_args: BenchmarkArgs,
+):
+    start_time = time.time()
+    total_tokens = 0
+
+    for _ in range(benchmark_args.test_steps):
+        data = labels.to(local_rank)
+        total_tokens += data.numel()
+
+        # with torch.cuda.amp.autocast():
+        loss = model(input_ids=data, labels=data).loss
+        loss.backward()
+
+        optimizer.step()
+        optimizer.zero_grad()
+
+    if rank == 0:
+        print(f"average {total_tokens / (time.time() - start_time) : .2f} tokens/s")
+
+
+if __name__ == "__main__":
+    args = init_args()
+
+    llama_recipes.utils.train_utils.setup()
+
+    local_rank = int(os.environ["LOCAL_RANK"])
+    rank = int(os.environ["RANK"])
+
+    torch.cuda.set_device(local_rank)
+
+    llama_recipes.utils.train_utils.clear_gpu_cache(local_rank)
+    llama_recipes.utils.train_utils.setup_environ_flags(rank)
+
+    fsdp_config = llama_recipes.configs.fsdp_config(
+        fsdp_activation_checkpointing=False, mixed_precision=False
+    )
+    train_config = llama_recipes.configs.train_config(
+        model_name=args.base_model, enable_fsdp=True, use_peft=True
+    )
+
+    fsdp_model = create_peft_model(rank, fsdp_config, train_config)
+    optimizer = create_optimizer(fsdp_model, train_config)
+
+    benchmark_args = BenchmarkArgs()
+    labels = dummy_train_labels(benchmark_args)
+
+    train(fsdp_model, optimizer, labels, rank, local_rank, benchmark_args)
diff --git a/benchmarks/bench_peft_tp.py b/benchmarks/bench_peft_tp.py
new file mode 100644
index 00000000..92d06447
--- /dev/null
+++ b/benchmarks/bench_peft_tp.py
@@ -0,0 +1,161 @@
+import argparse
+import os
+import random
+import time
+from dataclasses import dataclass
+
+import torch
+from peft import LoraConfig, PeftModelForCausalLM, TaskType, get_peft_model
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+from transformers import LlamaForCausalLM
+
+
+@dataclass
+class BenchmarkArgs:
+    batch_size: int = 24
+    seq_len: int = 512
+    test_steps: int = 10
+
+
+@dataclass
+class PeftArgs:
+    rank = 16
+    alpha = 16
+    dropout = 0.05
+    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
+
+
+def setup_lora_adapter(
+    model: LlamaForCausalLM, peft_args: PeftArgs
+) -> PeftModelForCausalLM:
+    peft_lora_config = LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        r=peft_args.rank,
+        lora_alpha=peft_args.alpha,
+        lora_dropout=peft_args.dropout,
+        target_modules=peft_args.target_modules,
+        bias="none",
+        inference_mode=False,
+    )
+    model = get_peft_model(model, peft_lora_config)
+
+    return model
+
+
+def init_args():
+    parser = argparse.ArgumentParser(description="PEFT TP benchmarks")
+    parser.add_argument(
+        "--base_model", type=str, required=True, help="Path to or name of base model"
+    )
+    return parser.parse_args()
+
+
+def dummy_train_labels() -> torch.Tensor:
+    batch_input_ids = []
+    for _ in range(0, BenchmarkArgs.batch_size):
+        batch_input_ids.append(
+            [random.randint(1, 10000) for _ in range(BenchmarkArgs.seq_len)]
+        )
+    return torch.tensor(
+        batch_input_ids,
+        dtype=torch.long,
+        device=torch.device(torch.cuda.current_device()),
+    )
+
+
+if __name__ == "__main__":
+    args = init_args()
+
+    local_rank = int(os.environ["LOCAL_RANK"])
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+
+    print("To init the device mesh")
+    tp_mesh = DeviceMesh(device_type="cuda", mesh=[0, 1, 2, 3])
+
+    print("To load the llama model")
+    model: LlamaForCausalLM = LlamaForCausalLM.from_pretrained(
+        args.base_model, use_cache=False, torch_dtype=torch.float32
+    )
+    for param in model.parameters():
+        param.requires_grad_(False)
+
+    model.lm_head = model.lm_head.to(torch.cuda.current_device())
+    model.model.embed_tokens = model.model.embed_tokens.to(torch.cuda.current_device())
+    model.model.norm = model.model.norm.to(torch.cuda.current_device())
+
+    print("To wrap the model")
+
+    for layer in model.model.layers:
+        layer_parallelize_plan = {
+            "self_attn.q_proj": ColwiseParallel(),
+            "self_attn.k_proj": ColwiseParallel(),
+            "self_attn.v_proj": ColwiseParallel(),
+            "self_attn.o_proj": RowwiseParallel(),
+            "mlp.gate_proj": ColwiseParallel(),
+            "mlp.up_proj": ColwiseParallel(),
+            "mlp.down_proj": RowwiseParallel(),
+        }
+
+        layer.input_layernorm = layer.input_layernorm.to(torch.cuda.current_device())
+        layer.post_attention_layernorm = layer.post_attention_layernorm.to(
+            torch.cuda.current_device()
+        )
+        layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(
+            torch.cuda.current_device()
+        )
+
+        layer.self_attn.num_heads = model.model.config.num_attention_heads // world_size
+        layer.self_attn.num_key_value_heads = (
+            model.model.config.num_key_value_heads // world_size
+        )
+        layer.self_attn.hidden_size = model.model.config.hidden_size // world_size
+
+        parallelize_module(layer, tp_mesh, layer_parallelize_plan)
+
+    print("To wrap the peft model")
+    peft_args = PeftArgs()
+    model: PeftModelForCausalLM = setup_lora_adapter(model, peft_args)
+    model.print_trainable_parameters()
+
+    for layer in model.base_model.model.model.layers:
+        layer_parallelize_plan = {
+            "self_attn.q_proj.lora_A.default": RowwiseParallel(
+                input_layouts=Replicate()
+            ),
+            "self_attn.k_proj.lora_A.default": RowwiseParallel(
+                input_layouts=Replicate()
+            ),
+            "self_attn.v_proj.lora_A.default": RowwiseParallel(
+                input_layouts=Replicate()
+            ),
+            "self_attn.o_proj.lora_A.default": ColwiseParallel(input_layouts=Shard(-1)),
+            "self_attn.q_proj.lora_B.default": ColwiseParallel(),
+            "self_attn.k_proj.lora_B.default": ColwiseParallel(),
+            "self_attn.v_proj.lora_B.default": ColwiseParallel(),
+            "self_attn.o_proj.lora_B.default": RowwiseParallel(),
+        }
+        parallelize_module(layer, tp_mesh, layer_parallelize_plan)
+
+    optimizer = torch.optim.AdamW(model.parameters())
+
+    start_time = time.time()
+    total_tokens = 0
+
+    for step in range(BenchmarkArgs.test_steps):
+        data = dummy_train_labels().to(local_rank)
+        total_tokens += data.numel()
+
+        loss = model(input_ids=data, labels=data).loss
+        loss.backward()
+
+        optimizer.step()
+        optimizer.zero_grad()
+
+    print(f"average {total_tokens / (time.time() - start_time) : .2f} tokens/s")
diff --git a/mlora/config/task.py b/mlora/config/task.py
index 4e46a9fb..949d2f26 100644
--- a/mlora/config/task.py
+++ b/mlora/config/task.py
@@ -10,6 +10,9 @@ class TaskConfig(DictConfig):
     name_: str
     type_: str
 
+    adapter_: AdapterConfig
+    dataset_: DatasetConfig | None
+
     __params_map: Dict[str, str] = {
         "name_": "name",
         "type_": "type",
diff --git a/mlora/executor/task/__init__.py b/mlora/executor/task/__init__.py
index 1898629c..94675ea3 100644
--- a/mlora/executor/task/__init__.py
+++ b/mlora/executor/task/__init__.py
@@ -1,9 +1,36 @@
+import logging
+from typing import MutableMapping, Type
+
 from .cit_task import CITTask
 from .cpo_task import CPOTask
 from .dpo_task import DPOTask
 from .task import Task
 from .train_task import TrainTask
 
-TASK_CLASS = {"train": TrainTask, "dpo": DPOTask, "cpo": CPOTask, "cit": CITTask}
+TASK_CLASS: MutableMapping[str, Type[Task]] = {
+    "train": TrainTask,
+    "dpo": DPOTask,
+    "cpo": CPOTask,
+    "cit": CITTask,
+}
+
+
+def register_task_class(type_name: str, task: Type[Task]):
+    global TASK_CLASS
+
+    if type_name in TASK_CLASS:
+        logging.info(f"Task type {type_name} already exists, skip registering it.")
+        return
+
+    TASK_CLASS[type_name] = task
+
 
-__all__ = ["Task", "TASK_CLASS", "TrainTask", "DPOTask", "CPOTask", "CITTask"]
+__all__ = [
+    "Task",
+    "TASK_CLASS",
+    "TrainTask",
+    "DPOTask",
+    "CPOTask",
+    "CITTask",
+    "register_task_class",
+]
diff --git a/mlora/utils/cmd.py b/mlora/utils/cmd.py
index 9cf45de1..bdc87bd2 100644
--- a/mlora/utils/cmd.py
+++ b/mlora/utils/cmd.py
@@ -35,7 +35,9 @@ def _add_base_cmd(parser):
     parser.add_argument("--rank", type=int, default=-1, help="The device's rank number")
     parser.add_argument("--balance", type=int, nargs="+", help="The model's balance")
     parser.add_argument(
-        "--recompute", type=bool, default=True, help="Enable recompute to save memory"
+        "--recompute",
+        action=argparse.BooleanOptionalAction,
+        help="Enable recompute to save memory",
     )
     # configuration about log
    parser.add_argument(
diff --git a/pyproject.mlora.toml b/pyproject.mlora.toml
index 37fcc4df..849dea71 100644
--- a/pyproject.mlora.toml
+++ b/pyproject.mlora.toml
@@ -24,7 +24,7 @@ dependencies = [
 
 [project.optional-dependencies]
 ci_test = ["pytest", "flake8", "lizard", "black", "isort", "mypy"]
-test = ["peft", "setuptools"]
+test = ["peft", "setuptools", "llama_recipes"]
 debug = ["graphviz"]
 deploy = ["fastapi", "plyvel", "uvicorn"]
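
Note on the task-registration hook: register_task_class is the extension point
added to mlora/executor/task/__init__.py, and bench_mlora_pp.py uses it to
inject BenchmarkTask before the executor starts. A minimal sketch of the
intended usage, assuming only the API added above (MyTask and "my_benchmark"
are illustrative names, not part of this patch):

    from mlora.executor.task import TASK_CLASS, TrainTask, register_task_class

    class MyTask(TrainTask):
        # Override data() and step(), as BenchmarkTask does in
        # bench_mlora_pp.py; the TrainTask behavior is inherited unchanged here.
        pass

    register_task_class("my_benchmark", MyTask)  # usable as "type": "my_benchmark"
    register_task_class("train", MyTask)         # existing name: logged and skipped
    assert TASK_CLASS["train"] is TrainTask
    assert TASK_CLASS["my_benchmark"] is MyTask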
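
Note on the mlora/utils/cmd.py hunk: the old flag used type=bool, and argparse
calls bool() on the raw string, so any non-empty value (even "--recompute
false") parsed as True. argparse.BooleanOptionalAction instead generates a
paired --recompute/--no-recompute switch. One behavior change worth noting:
with default=True removed, the value is None when neither flag is given. A
self-contained repro of the new behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--recompute",
        action=argparse.BooleanOptionalAction,
        help="Enable recompute to save memory",
    )

    print(parser.parse_args(["--recompute"]).recompute)     # True
    print(parser.parse_args(["--no-recompute"]).recompute)  # False
    print(parser.parse_args([]).recompute)                  # None (falsy)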
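
Launch note: bench_peft_fsdp.py and bench_peft_tp.py read LOCAL_RANK, RANK,
and WORLD_SIZE from the environment, so they are presumably started with a
distributed launcher such as torchrun, e.g.
torchrun --nproc_per_node=4 benchmarks/bench_peft_tp.py --base_model <path>
(the exact command is an assumption, not taken from this patch). The TP script
hardcodes a four-GPU DeviceMesh (mesh=[0, 1, 2, 3]), so it assumes exactly four
visible devices, while bench_mlora_pp.py is driven by mlora's own command-line
flags (--rank, --balance, --recompute).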