Commit

add lora training
tianweiy committed May 26, 2024
1 parent 50a0db8 commit 3f3a60e
Showing 3 changed files with 35 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -51,7 +51,7 @@ conda activate dmd2
 
 pip install --upgrade anyio
 pip install torch==2.0.1 torchvision==0.15.2
-pip install --upgrade diffusers wandb lmdb transformers accelerate==0.23.0 lmdb datasets evaluate scipy opencv-python matplotlib imageio piq==0.7.0 safetensors gradio
+pip install --upgrade diffusers peft wandb lmdb transformers accelerate==0.23.0 lmdb datasets evaluate scipy opencv-python matplotlib imageio piq==0.7.0 safetensors gradio
 python setup.py develop
 ```

30 changes: 29 additions & 1 deletion main/sd_unified_model.py
@@ -1,4 +1,5 @@
 # A single unified model that wraps both the generator and discriminator
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from diffusers import UNet2DConditionModel, AutoencoderKL, AutoencoderTiny
 from main.utils import get_x0_from_noise, NoOpContext
 from main.sdxl.sdxl_text_encoder import SDXLTextEncoder
@@ -40,7 +41,34 @@ def __init__(self, args, accelerator):
             subfolder="unet"
         ).float()
 
-        self.feedforward_model.requires_grad_(True)
+        if args.generator_lora:
+            self.feedforward_model.requires_grad_(False)
+            assert args.sdxl
+            lora_target_modules = [
+                "to_q",
+                "to_k",
+                "to_v",
+                "to_out.0",
+                "proj_in",
+                "proj_out",
+                "ff.net.0.proj",
+                "ff.net.2",
+                "conv1",
+                "conv2",
+                "conv_shortcut",
+                "downsamplers.0.conv",
+                "upsamplers.0.conv",
+                "time_emb_proj",
+            ]
+            lora_config = LoraConfig(
+                r=args.lora_rank,
+                target_modules=lora_target_modules,
+                lora_alpha=args.lora_alpha,
+                lora_dropout=args.lora_dropout
+            )
+            self.feedforward_model.add_adapter(lora_config)
+        else:
+            self.feedforward_model.requires_grad_(True)
 
         if self.gradient_checkpointing:
             self.feedforward_model.enable_gradient_checkpointing()
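Note that the commit imports `get_peft_model_state_dict` and `set_peft_model_state_dict` from peft, but the hunk shown here only attaches the adapter. Below is a minimal sketch of how those helpers are typically used for adapter-only checkpointing; the `unet` variable and the checkpoint path are placeholders, not code from this repository:

```python
import torch
from peft import get_peft_model_state_dict, set_peft_model_state_dict

def save_lora_checkpoint(unet, path):
    # Collect only the LoRA adapter tensors; the frozen base weights are skipped.
    lora_state_dict = get_peft_model_state_dict(unet)
    torch.save(lora_state_dict, path)

def load_lora_checkpoint(unet, path):
    # The model must already have the adapter attached via add_adapter(lora_config).
    lora_state_dict = torch.load(path, map_location="cpu")
    set_peft_model_state_dict(unet, lora_state_dict)
```

This keeps checkpoints small: only the low-rank adapter matrices are serialized, not the full UNet.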
5 changes: 5 additions & 0 deletions main/train_sd.py
@@ -676,6 +676,11 @@ def parse_args():
     parser.add_argument("--gan_alone", action="store_true", help="only use the gan loss without dmd")
     parser.add_argument("--backward_simulation", action="store_true")
 
+    parser.add_argument("--generator_lora", action="store_true")
+    parser.add_argument("--lora_rank", type=int, default=64)
+    parser.add_argument("--lora_alpha", type=float, default=8)
+    parser.add_argument("--lora_dropout", type=float, default=0.0)
+
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
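With `--generator_lora` set, the base UNet is frozen and only the injected adapter weights have `requires_grad=True`; in peft's convention the defaults imply an effective LoRA scaling of `lora_alpha / lora_rank = 8 / 64 = 0.125`. Downstream code would then build the generator optimizer over the trainable subset only. A hedged sketch, with the variable names and learning rate illustrative rather than taken from this commit:

```python
import torch

# Only the LoRA adapter parameters pass this filter when --generator_lora is set;
# the frozen base UNet weights are excluded from the optimizer.
lora_params = [p for p in model.feedforward_model.parameters() if p.requires_grad]
optimizer_generator = torch.optim.AdamW(lora_params, lr=1e-5)  # illustrative lr
```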
