tested for 1B
Guitaricet committed Jul 21, 2023
1 parent bf1873e commit 7ac8e53
Showing 6 changed files with 184 additions and 67 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -8,7 +8,7 @@ All requirements are listed in `requirements.txt` and kept up-to-date.

```bash
cd peft_pretraining
-pip install -r requirements.txt
+pip install -e .
```
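
The editable install (backed by the new `setup.py` added below) makes `peft_pretraining` importable as a regular package. A minimal check, assuming the dependencies from `requirements.txt` are installed:

```python
# After `pip install -e .`, the package resolves without sys.path tweaks.
from peft_pretraining import training_utils
```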

## 1B training script
20 changes: 20 additions & 0 deletions configs/llama_250m_old.json
@@ -0,0 +1,20 @@
{
"architectures": [
"LLaMAForCausalLM"
],
"bos_token_id": 0,
"eos_token_id": 1,
"hidden_act": "silu",
"hidden_size": 768,
"intermediate_size": 2560,
"initializer_range": 0.02,
"max_sequence_length": 1024,
"model_type": "llama",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"pad_token_id": -1,
"rms_norm_eps": 1e-06,
"transformers_version": "4.28.1",
"use_cache": true,
"vocab_size": 32000
}
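
As a rough, editorial sanity check on the file name (not part of the diff), the parameter count implied by this config lands near 250M, assuming untied input/output embeddings and the standard LLaMA gate/up/down MLP:

```python
hidden, inter, layers, vocab = 768, 2560, 24, 32000

attn_per_layer = 4 * hidden * hidden   # q, k, v, o projections
mlp_per_layer = 3 * hidden * inter     # gate, up, down projections
embeddings = vocab * hidden            # input embedding table
lm_head = vocab * hidden               # output head (assumed untied)

total = layers * (attn_per_layer + mlp_per_layer) + embeddings + lm_head
print(f"{total / 1e6:.1f}M parameters")  # ~247.3M
```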
2 changes: 1 addition & 1 deletion peft_pretraining/args_utils.py
@@ -44,7 +44,7 @@ def check_args_torchrun_main(args):
logger.info(f"Training for {args.num_training_steps} update steps")

if args.warmed_up_model is not None:
assert os.path.exists(args.contwarmed_up_modelinue_from), f"{args.warmed_up_model=} does not exist"
assert os.path.exists(args.warmed_up_model), f"{args.warmed_up_model=} does not exist"

if args.dtype in ["fp16", "float16"]:
raise NotImplementedError("fp16 is not supported in torchrun_main.py. Use deepspeed_main.py instead (but it seems to have bugs)")
90 changes: 76 additions & 14 deletions peft_pretraining/training_utils.py
@@ -6,11 +6,51 @@
import torch
from torch.optim.lr_scheduler import LambdaLR
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

import transformers

from loguru import logger


from peft_pretraining.modeling_llama import LlamaDecoderLayer


def initialize_fsdp(model, dtype):
    wrapping_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )

    if dtype in ["bf16", "bfloat16"]:
        mixed_precision_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,  # Gradient communication precision
            buffer_dtype=torch.bfloat16,  # Buffer precision
        )
    elif dtype == "float32":
        mixed_precision_policy = MixedPrecision(
            param_dtype=torch.float32,
            reduce_dtype=torch.float32,  # Gradient communication precision
            buffer_dtype=torch.float32,  # Buffer precision
        )
    else:
        raise ValueError(f"Dtype {dtype} not supported (only float32 and bfloat16 are)")

    model = FSDP(
        model,
        mixed_precision=mixed_precision_policy,
        auto_wrap_policy=wrapping_policy,
    )
    return model
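
For context, a minimal usage sketch of `initialize_fsdp` (assuming a `torchrun` launch and a hypothetical `build_model()` helper, neither of which is shown in this diff):

```python
import os

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = build_model().cuda()                      # hypothetical model constructor
model = initialize_fsdp(model, dtype="bfloat16")  # shards parameters at LlamaDecoderLayer granularity
```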


def get_scheculer(
    optimizer,
    *,
@@ -165,8 +205,8 @@ def _get_cosine_schedule_with_multiple_warmups_lambda(
"""
assert 0 < min_lr_ratio <= 1.0, "min_lr_ratio must be in (0,1]"
assert restart_every > 0, "restart_every must be positive"
assert adjust_step + first_warmup_steps < num_training_steps, "warmup + adjust_step is more than full training steps"
assert adjust_step + first_warmup_steps < restart_every, "the first reset will happen before the warmup is done"
assert adjust_step + first_warmup_steps <= num_training_steps, "warmup + adjust_step is more than full training steps"
assert adjust_step + first_warmup_steps <= restart_every, "the first reset will happen before the warmup is done"

if current_step < first_warmup_steps:
return float(current_step) / float(max(1, first_warmup_steps))
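
To make the relaxed constraints concrete, a configuration like the following (illustrative numbers, not from the repository) now passes both asserts, with the first ReLoRA reset allowed exactly when warmup completes:

```python
num_training_steps = 20_000
first_warmup_steps = 1_000
restart_every = 1_000   # reset allowed as soon as warmup is done (<=, not <)
adjust_step = 0

assert adjust_step + first_warmup_steps <= num_training_steps
assert adjust_step + first_warmup_steps <= restart_every
```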
@@ -213,7 +253,7 @@ def get_last_training_state(save_dir):
        return None, None

    model_dirs = sorted(model_dirs, key=lambda x: int(x.split("_")[-1]))
-    resume_from = model_dirs[-1]
+    resume_from = os.path.join(save_dir, model_dirs[-1])

    logger.info(f"Restarting training from {resume_from}")
    with open(os.path.join(resume_from, "training_state.json")) as f:
@@ -248,7 +288,9 @@ def optimizer_reset(
    # pruning_fn has to be inplace to work with ZeroRedundancyOptimizer
    if reset_optimizer_on_relora:
        logger.info("Resetting optimizer states to zeros")
-        pruning_fn = partial(random_pruning_, prune_ratio=1)
+        # looks like zeroing out breaks the optimizer's state dictionary
+        # see the full error below
+        pruning_fn = partial(random_pruning_, prune_ratio=0.999)
    elif optimizer_random_pruning:
        logger.info(f"Performing random pruning of optimizer states. "
                    f"Pruning {optimizer_random_pruning} percent")
@@ -260,16 +302,6 @@
    else:
        raise ValueError("Unknown pruning type")

-    # import time
-    # import torch.distributed as dist
-    # rank = dist.get_rank()
-    # if rank == 1:
-    #     time.sleep(5)
-    # print("*"*100)
-    # print(f"rank: {rank}")
-    # print(f"Optimizer state values for rank {rank}: {optimizer.optim.state.items()}")
-    # print("*"*100)

    # ############################################################
    # A reminder on how optimizer state is structured for regular optimizers:
    # optimizer.state is a dict[torch.nn.Parameter, dict[str, torch.Tensor]]
@@ -281,10 +313,40 @@
    # For ZeroRedundancyOptimizer, it works differently.
    # ZeroRedundancyOptimizer.state always maps to empty dicts.
    # Instead, it uses optimizer.optim.state for rank-local updates.
+    #
+    # For some reason, zeroing out a tensor in ZeroRedundancyOptimizer.optim.state
+    # causes an error during state_dict collection.
+    # This is why we use a 0.999 pruning ratio for the reset_optimizer case.
+    #
+    # Here's the error that happens:
+    #
+    # Traceback (most recent call last):
+    #   File ".../peft_pretraining/torchrun_main.py", line 866, in <module>
+    #     main(args)
+    #   File ".../peft_pretraining/torchrun_main.py", line 715, in main
+    #     save_model(
+    #   File ".../peft_pretraining/torchrun_main.py", line 289, in save_model
+    #     save_model_ddp(model, optimizer, scheduler, training_state_checkpoint, run_config, save_dir)
+    #   File ".../peft_pretraining/torchrun_main.py", line 224, in save_model_ddp
+    #     optimizer.consolidate_state_dict()
+    #   File ".../python3.10/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py", line 565, in consolidate_state_dict
+    #     self.optim.state_dict(),
+    #   File ".../python3.10/site-packages/torch/optim/optimizer.py", line 364, in state_dict
+    #     packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
+    #   File ".../python3.10/site-packages/torch/optim/optimizer.py", line 364, in <dictcomp>
+    #     packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
+    # KeyError: 140580723685184
+    #
+    # On one hand, the hypothesis is that making a zero tensor
+    # is implemented by changing the pointer in memory to
+    # an existing zero-tensor. On the other hand, we didn't
+    # have issues with that when using regular Adam, without the ZeroRedundancyOptimizer wrapper.
    # ############################################################
    n_zeros = 0
    n_total = 0

    from torch.optim.optimizer import Optimizer

    optimizer_state = optimizer.state
    if isinstance(optimizer, ZeroRedundancyOptimizer):
        optimizer_state = optimizer.optim.state
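
The rest of the function is not shown in this diff; a hypothetical sketch (not the repository's actual code) of how such a traversal could apply `pruning_fn` to the Adam state tensors and track sparsity with the `n_zeros` / `n_total` counters:

```python
for param, param_state in optimizer_state.items():
    for key in ("exp_avg", "exp_avg_sq"):
        state_tensor = param_state.get(key)
        if state_tensor is None:
            continue
        pruning_fn(state_tensor)                    # in-place, as required by ZeroRedundancyOptimizer
        n_zeros += (state_tensor == 0).sum().item()
        n_total += state_tensor.numel()
```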
16 changes: 16 additions & 0 deletions setup.py
@@ -0,0 +1,16 @@
from setuptools import setup

with open("requirements.txt") as f:
    required = f.read().splitlines()

setup(
name="peft_pretraining",
version="1.0",
description="ReLoRA: Parameter-efficient pre-training",
url="https://github.com/Guitaricet/peft_pretraining",
author="Vlad Lialin",
author_email="[email protected]",
license="Apache 2.0",
packages=["peft_pretraining"],
install_requires=required,
)