unpaired training
GaParmar committed Apr 10, 2024
1 parent 60ba021 commit db6cf92
Showing 14 changed files with 1,242 additions and 181 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -143,7 +143,7 @@ We tightly integrate three separate modules in the original latent diffusion mod
**Unpaired Image Translation (CycleGAN-Turbo)**
- The following command takes a **day** image file as input, and saves the output **night** in the directory specified.
```
python src/inference_unpaired.py --model "day_to_night" \
python src/inference_unpaired.py --model_name "day_to_night" \
--input_image "assets/examples/day2night_input.png" --output_dir "outputs"
```
<table>
@@ -158,7 +158,7 @@ We tightly integrate three separate modules in the original latent diffusion mod
- The following command takes a **night** image file as input, and saves the output **day** in the directory specified.
```
python src/inference_unpaired.py --model "night_to_day" \
python src/inference_unpaired.py --model_name "night_to_day" \
--input_image "assets/examples/night2day_input.png" --output_dir "outputs"
```
<table>
@@ -173,7 +173,7 @@ We tightly integrate three separate modules in the original latent diffusion mod
- The following command takes a **clear** image file as input, and saves the output **rainy** in the directory specified.
```
python src/inference_unpaired.py --model "clear_to_rainy" \
python src/inference_unpaired.py --model_name "clear_to_rainy" \
--input_image "assets/examples/clear2rainy_input.png" --output_dir "outputs"
```
<table>
@@ -188,7 +188,7 @@ We tightly integrate three separate modules in the original latent diffusion mod
- The following command takes a **rainy** image file as input, and saves the output **clear** in the directory specified.
```
python src/inference_unpaired.py --model "rainy_to_clear" \
python src/inference_unpaired.py --model_name "rainy_to_clear" \
--input_image "assets/examples/rainy2clear_input.png" --output_dir "outputs"
```
<table>
Binary file added assets/examples/my_horse2zebra_input.jpg
Binary file added assets/examples/my_horse2zebra_output.jpg
Binary file added assets/examples/training_evaluation_unpaired.png
90 changes: 89 additions & 1 deletion docs/training_cyclegan_turbo.md
@@ -4,7 +4,95 @@ We will use the [horse2zebra dataset](https://github.com/junyanz/pytorch-CycleGA


### Step 1. Get the Dataset
- First download the horse2zebra dataset from [here](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip).
- First download the horse2zebra dataset from [here](https://www.cs.cmu.edu/~img2img-turbo/data/my_horse2zebra.zip) using the command below.
```
bash scripts/download_horse2zebra.sh
```
- Our training scripts expect the dataset to be in the following format:
```
data
├── dataset_name
│ ├── train_A
│ │ ├── 000000.png
│ │ ├── 000001.png
│ │ └── ...
│ ├── train_B
│ │ ├── 000000.png
│ │ ├── 000001.png
│ │ └── ...
│ ├── test_A
│ │ ├── 000000.png
│ │ ├── 000001.png
│ │ └── ...
│ ├── test_B
│ │ ├── 000000.png
│ │ ├── 000001.png
│ │ └── ...
│ ├── fixed_prompt_a.txt
│ └── fixed_prompt_b.txt
```
- The `fixed_prompt_a.txt` and `fixed_prompt_b.txt` files contain the **fixed caption** used for the source and target domains respectively.
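- For the horse2zebra example these are single-line captions, e.g. (illustrative contents; `picture of a zebra` is the prompt used at inference in Step 4, and the domain-A caption is an assumption — pick captions that describe your own domains):
```
fixed_prompt_a.txt: picture of a horse
fixed_prompt_b.txt: picture of a zebra
```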
### Step 2. Train the Model
- Initialize the `accelerate` environment with the following command:
```
accelerate config
```
- Run the following command to train the model.
```
export NCCL_P2P_DISABLE=1
accelerate launch --main_process_port 29501 src/train_cyclegan_turbo.py \
--pretrained_model_name_or_path="stabilityai/sd-turbo" \
--output_dir="output/cyclegan_turbo/my_horse2zebra" \
--dataset_folder "data/my_horse2zebra" \
--train_img_prep "resize_286_randomcrop_256x256_hflip" --val_img_prep "no_resize" \
--learning_rate="1e-5" --max_train_steps=25000 \
--train_batch_size=1 --gradient_accumulation_steps=1 \
--report_to "wandb" --tracker_project_name "gparmar_unpaired_h2z_cycle_debug_v2" \
--enable_xformers_memory_efficient_attention --validation_steps 250 \
--lambda_gan 0.5 --lambda_idt 1 --lambda_cycle 1
```
- Additional optional flags:
- `--enable_xformers_memory_efficient_attention`: Enable memory-efficient attention in the model.
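- The `--lambda_gan`, `--lambda_idt`, and `--lambda_cycle` flags weight the individual loss terms. Assuming the standard CycleGAN-style objective (the exact composition is defined in `src/train_cyclegan_turbo.py`), the total generator loss is roughly
```math
\mathcal{L}_{G} \approx \lambda_{\text{gan}}\,\mathcal{L}_{\text{GAN}} + \lambda_{\text{cycle}}\,\mathcal{L}_{\text{cycle}} + \lambda_{\text{idt}}\,\mathcal{L}_{\text{idt}}
```
  so the command above weights the adversarial term at 0.5 and the cycle-consistency and identity terms at 1.0.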
### Step 3. Monitor the training progress
- You can monitor the training progress using the [Weights & Biases](https://wandb.ai/site) dashboard.
- The training script will visualize the training batch, the training losses, and the validation-set L2, LPIPS, and FID scores (if specified).
<div>
<p align="center">
<img src='../assets/examples/training_evaluation_unpaired.png' align="center" width=800px>
</p>
</div>
- The model checkpoints will be saved in the `<output_dir>/checkpoints` directory.
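- For the command above, the checkpoint folder will look roughly like this (step numbers in the file names are illustrative; `model_1001.pkl` is the checkpoint used in Step 4 below):
```
output/cyclegan_turbo/my_horse2zebra/
└── checkpoints
    ├── model_1001.pkl
    └── ...
```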
### Step 4. Running Inference with the trained models
- You can run inference with the trained model using the following command:
```
python src/inference_unpaired.py --model_path "output/cyclegan_turbo/my_horse2zebra/checkpoints/model_1001.pkl" \
--input_image "data/my_horse2zebra/test_A/n02381460_20.jpg" \
--prompt "picture of a zebra" --direction "a2b" \
--output_dir "outputs" --image_prep "no_resize"
```
- The above command should generate the following output:
<table>
<tr>
<th>Model Input</th>
<th>Model Output</th>
</tr>
<tr>
<td><img src='../assets/examples/my_horse2zebra_input.jpg' width="200px"></td>
<td><img src='../assets/examples/my_horse2zebra_output.jpg' width="200px"></td>
</tr>
</table>
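- If you prefer calling the model from Python rather than the CLI, the sketch below mirrors what `src/inference_unpaired.py` does with the checkpoint and prompt from the command above. It is a minimal sketch, not the script itself: run it with the repository's `src/` directory on your `PYTHONPATH`, and note that the output file name is an assumption.
```
import torch
from PIL import Image
from torchvision import transforms
from cyclegan_turbo import CycleGAN_Turbo
from my_utils.training_utils import build_transform

# load the locally trained checkpoint (path from the Step 4 command above)
model = CycleGAN_Turbo(pretrained_path="output/cyclegan_turbo/my_horse2zebra/checkpoints/model_1001.pkl")
model.eval()

# same preprocessing as inference_unpaired.py: image prep, then ToTensor + Normalize to [-1, 1]
T_val = build_transform("no_resize")
input_image = Image.open("data/my_horse2zebra/test_A/n02381460_20.jpg").convert("RGB")

with torch.no_grad():
    x_t = transforms.ToTensor()(T_val(input_image))
    x_t = transforms.Normalize([0.5], [0.5])(x_t).unsqueeze(0).cuda()
    # direction "a2b" = horse -> zebra; the caption is the target-domain prompt
    output = model(x_t, direction="a2b", caption="picture of a zebra")

# outputs are in [-1, 1]; map back to [0, 1] and save (output file name is illustrative)
output_pil = transforms.ToPILImage()(output[0].cpu() * 0.5 + 0.5)
output_pil.save("outputs/my_horse2zebra_output.png")
```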
5 changes: 5 additions & 0 deletions scripts/download_horse2zebra.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mkdir -p data
wget https://www.cs.cmu.edu/~img2img-turbo/data/my_horse2zebra.zip -O data/my_horse2zebra.zip
cd data
unzip my_horse2zebra.zip
rm my_horse2zebra.zip
150 changes: 135 additions & 15 deletions src/cyclegan_turbo.py
@@ -6,6 +6,7 @@
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
p = "src/"
sys.path.append(p)
from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd, download_url
@@ -44,8 +45,69 @@ def forward(self, x, direction):
return x_decoded


def initialize_unet(rank, return_lora_module_names=False):
unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
unet.requires_grad_(False)
unet.train()
l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in", "ff.net.2", "ff.net.0.proj"]
for n, p in unet.named_parameters():
if "bias" in n or "norm" in n: continue
for pattern in l_grep:
if pattern in n and ("down_blocks" in n or "conv_in" in n):
l_target_modules_encoder.append(n.replace(".weight",""))
break
elif pattern in n and "up_blocks" in n:
l_target_modules_decoder.append(n.replace(".weight",""))
break
elif pattern in n:
l_modules_others.append(n.replace(".weight",""))
break
lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_encoder, lora_alpha=rank)
lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_decoder, lora_alpha=rank)
lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_modules_others, lora_alpha=rank)
unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
unet.add_adapter(lora_conf_others, adapter_name="default_others")
unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
if return_lora_module_names:
return unet, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
else:
return unet


def initialize_vae(rank=4, return_lora_module_names=False):
vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
vae.requires_grad_(False)
vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
vae.requires_grad_(True)
vae.train()
# add the skip connection convs
vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5)
torch.nn.init.constant_(vae.decoder.skip_conv_2.weight, 1e-5)
torch.nn.init.constant_(vae.decoder.skip_conv_3.weight, 1e-5)
torch.nn.init.constant_(vae.decoder.skip_conv_4.weight, 1e-5)
vae.decoder.ignore_skip = False
vae.decoder.gamma = 1
l_vae_target_modules = ["conv1","conv2","conv_in", "conv_shortcut",
"conv", "conv_out", "skip_conv_1", "skip_conv_2", "skip_conv_3",
"skip_conv_4", "to_k", "to_q", "to_v", "to_out.0",
]
vae_lora_config = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_vae_target_modules)
vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
if return_lora_module_names:
return vae, l_vae_target_modules
else:
return vae


class CycleGAN_Turbo(torch.nn.Module):
def __init__(self, pretrained_name, ckpt_folder="checkpoints"):
def __init__(self, pretrained_name=None, pretrained_path=None, ckpt_folder="checkpoints", lora_rank_unet=8, lora_rank_vae=4):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
@@ -85,15 +147,19 @@ def __init__(self, pretrained_name, ckpt_folder="checkpoints"):
self.timesteps = torch.tensor([999], device="cuda").long()
self.caption = "driving in the day"
self.direction = "b2a"

elif pretrained_path is not None:
sd = torch.load(pretrained_path)
self.load_ckpt_from_state_dict(sd)
self.timesteps = torch.tensor([999], device="cuda").long()
self.caption = None
self.direction = None

self.vae_enc.cuda()
self.vae_dec.cuda()
self.unet.cuda()

def load_ckpt_from_url(self, url, ckpt_folder):
os.makedirs(ckpt_folder, exist_ok=True)
outf = os.path.join(ckpt_folder, os.path.basename(url))
download_url(url, outf)
sd = torch.load(outf)
def load_ckpt_from_state_dict(self, sd):
lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_encoder"], lora_alpha=sd["rank_unet"])
lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_decoder"], lora_alpha=sd["rank_unet"])
lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_modules_others"], lora_alpha=sd["rank_unet"])
@@ -123,12 +189,66 @@ def load_ckpt_from_url(self, url, ckpt_folder):
self.vae_dec = VAE_decode(self.vae, vae_b2a=self.vae_b2a)
self.vae_dec.load_state_dict(sd["sd_vae_dec"])

def forward(self, x_t):
caption_tokens = self.tokenizer(self.caption, max_length=self.tokenizer.model_max_length,
padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
caption_enc = self.text_encoder(caption_tokens)[0]
x_t_enc = self.vae_enc(x_t, direction=self.direction)
model_pred = self.unet(x_t_enc, self.timesteps, encoder_hidden_states=caption_enc,).sample
x_denoised = self.sched.step(model_pred, self.timesteps, x_t_enc, return_dict=True).prev_sample
output = self.vae_dec(x_denoised, direction=self.direction)
return output
def load_ckpt_from_url(self, url, ckpt_folder):
os.makedirs(ckpt_folder, exist_ok=True)
outf = os.path.join(ckpt_folder, os.path.basename(url))
download_url(url, outf)
sd = torch.load(outf)
self.load_ckpt_from_state_dict(sd)

@staticmethod
def forward_with_networks(x, direction, vae_enc, unet, vae_dec, sched, timesteps, text_emb):
B = x.shape[0]
assert direction in ["a2b", "b2a"]
x_enc = vae_enc(x, direction=direction).to(x.dtype)
model_pred = unet(x_enc, timesteps, encoder_hidden_states=text_emb,).sample
x_out = torch.stack([sched.step(model_pred[i], timesteps[i], x_enc[i], return_dict=True).prev_sample for i in range(B)])
x_out_decoded = vae_dec(x_out, direction=direction)
return x_out_decoded

@staticmethod
def get_traininable_params(unet, vae_a2b, vae_b2a):
# add all unet parameters
params_gen = list(unet.conv_in.parameters())
unet.conv_in.requires_grad_(True)
unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
for n,p in unet.named_parameters():
if "lora" in n and "default" in n:
assert p.requires_grad
params_gen.append(p)

# add all vae_a2b parameters
for n,p in vae_a2b.named_parameters():
if "lora" in n and "vae_skip" in n:
assert p.requires_grad
params_gen.append(p)
params_gen = params_gen + list(vae_a2b.decoder.skip_conv_1.parameters())
params_gen = params_gen + list(vae_a2b.decoder.skip_conv_2.parameters())
params_gen = params_gen + list(vae_a2b.decoder.skip_conv_3.parameters())
params_gen = params_gen + list(vae_a2b.decoder.skip_conv_4.parameters())

# add all vae_b2a parameters
for n,p in vae_b2a.named_parameters():
if "lora" in n and "vae_skip" in n:
assert p.requires_grad
params_gen.append(p)
params_gen = params_gen + list(vae_b2a.decoder.skip_conv_1.parameters())
params_gen = params_gen + list(vae_b2a.decoder.skip_conv_2.parameters())
params_gen = params_gen + list(vae_b2a.decoder.skip_conv_3.parameters())
params_gen = params_gen + list(vae_b2a.decoder.skip_conv_4.parameters())
return params_gen

def forward(self, x_t, direction=None, caption=None, caption_emb=None):
if direction is None:
assert self.direction is not None
direction = self.direction
if caption is None and caption_emb is None:
assert self.caption is not None
caption = self.caption
if caption_emb is not None:
caption_enc = caption_emb
else:
caption_tokens = self.tokenizer(caption, max_length=self.tokenizer.model_max_length,
padding="max_length", truncation=True, return_tensors="pt").input_ids.to(x_t.device)
caption_enc = self.text_encoder(caption_tokens)[0].detach().clone()
return self.forward_with_networks(x_t, direction, self.vae_enc, self.unet, self.vae_dec, self.sched, self.timesteps, caption_enc)
3 changes: 2 additions & 1 deletion src/inference_paired.py
@@ -25,6 +25,8 @@
if (args.model_name == '') == (args.model_path == ''):
raise ValueError('Either model_name or model_path should be provided')

os.makedirs(args.output_dir, exist_ok=True)

# initialize the model
model = Pix2Pix_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
model.set_eval()
@@ -60,5 +62,4 @@
output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)

# save the output image
os.makedirs(args.output_dir, exist_ok=True)
output_pil.save(os.path.join(args.output_dir, bname))
33 changes: 23 additions & 10 deletions src/inference_unpaired.py
@@ -4,32 +4,45 @@
import torch
from torchvision import transforms
from cyclegan_turbo import CycleGAN_Turbo
from my_utils.training_utils import build_transform


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
parser.add_argument('--model_name', type=str, default='day_to_night', help='name of the model to be used')
parser.add_argument('--prompt', type=str, required=False, help='the prompt to be used. It is required when loading a custom model_path.')
parser.add_argument('--model_name', type=str, default=None, help='name of the pretrained model to be used')
parser.add_argument('--model_path', type=str, default=None, help='path to a local model state dict to be used')
parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
parser.add_argument('--image_prep', type=str, default='resize_512x512', help='the image preparation method')
parser.add_argument('--direction', type=str, default=None, help='the direction of translation. None for pretrained models, a2b or b2a for custom paths.')
args = parser.parse_args()

# only one of model_name and model_path should be provided
if (args.model_name is None) == (args.model_path is None):
raise ValueError('Either model_name or model_path should be provided')

if args.model_path is not None and args.prompt is None:
raise ValueError('prompt is required when loading a custom model_path.')

if args.model_name is not None:
assert args.prompt is None, 'prompt is not required when loading a pretrained model.'
assert args.direction is None, 'direction is not required when loading a pretrained model.'

# initialize the model
model = CycleGAN_Turbo(pretrained_name=args.model_name)
model = CycleGAN_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
model.eval()
model.unet.enable_xformers_memory_efficient_attention()

if args.image_prep == "resize_512x512":
T_val = transforms.Compose([
transforms.Resize((512, 512), interpolation=Image.LANCZOS),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
T_val = build_transform(args.image_prep)

input_image = Image.open(args.input_image).convert('RGB')
# translate the image
with torch.no_grad():
x_t = T_val(input_image).unsqueeze(0).cuda()
output = model(x_t)
input_img = T_val(input_image)
x_t = transforms.ToTensor()(input_img)
x_t = transforms.Normalize([0.5], [0.5])(x_t).unsqueeze(0).cuda()
output = model(x_t, direction=args.direction, caption=args.prompt)

output_pil = transforms.ToPILImage()(output[0].cpu() * 0.5 + 0.5)
output_pil = output_pil.resize((input_image.width, input_image.height), Image.LANCZOS)
