Commit: paired training code
GaParmar committed Mar 24, 2024
1 parent 51c98c8 commit f1a50d5
Showing 25 changed files with 854 additions and 64 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -166,3 +166,6 @@ checkpoints/
img2img-turbo-sketch
outputs/
outputs/bird.png
data
wandb
output/
61 changes: 53 additions & 8 deletions README.md
@@ -101,27 +101,72 @@ We tightly integrate three separate modules in the original latent diffusion mod
**Paired Image Translation (pix2pix-turbo)**
- The following command takes an image file and a prompt as inputs, extracts the Canny edges, and saves the results in the specified directory (a minimal sketch of the edge extraction is shown after the examples below).
```bash
-python src/inference_paired.py --model "edge_to_image" \
-    --input_image "assets/bird.png" \
+python src/inference_paired.py --model_name "edge_to_image" \
+    --input_image "assets/examples/bird.png" \
    --prompt "a blue bird" \
    --output_dir "outputs"
```
<table>
<tr>
<th>Input Image</th>
<th>Canny Edges</th>
<th>Model Output</th>
</tr>
<tr>
<td><img src='assets/examples/bird.png' width="200px"></td>
<td><img src='assets/examples/bird_canny.png' width="200px"></td>
<td><img src='assets/examples/bird_canny_blue.png' width="200px"></td>
</tr>
</table>
<br>
- The following command takes a sketch and a prompt as inputs, and saves the results in the directory specified.
```bash
-python src/inference_paired.py --model "sketch_to_image_stochastic" \
-    --input_image "assets/sketch.png" --gamma 0.4 \
+python src/inference_paired.py --model_name "sketch_to_image_stochastic" \
+    --input_image "assets/examples/sketch_input.png" --gamma 0.4 \
    --prompt "ethereal fantasy concept art of an asteroid. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy" \
    --output_dir "outputs"
```
<table>
<tr>
<th>Input</th>
<th>Model Output</th>
</tr>
<tr>
<td><img src='assets/examples/sketch_input.png' width="400px"></td>
<td><img src='assets/examples/sketch_output.png' width="400px"></td>
</tr>
</table>
<br>
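For reference, the Canny edge extraction in the `edge_to_image` example above can be reproduced with OpenCV. The snippet below is a minimal sketch under the assumption that the extraction is a plain `cv2.Canny` call with the CLI's default thresholds; it is not necessarily the repository's exact `canny_from_pil` implementation.

```python
# A minimal sketch of Canny edge extraction for the edge_to_image model.
# Assumptions: a plain cv2.Canny call with thresholds mirroring the CLI
# defaults (--low_threshold 100, --high_threshold 200); not necessarily
# the repository's exact canny_from_pil implementation.
import cv2
import numpy as np
from PIL import Image

def canny_from_pil(image, low_threshold=100, high_threshold=200):
    edges = cv2.Canny(np.array(image), low_threshold, high_threshold)  # single-channel uint8 edge map
    return Image.fromarray(edges).convert("RGB")  # replicate to 3 channels for the model input

edges = canny_from_pil(Image.open("assets/examples/bird.png").convert("RGB"))
edges.save("outputs/bird_canny.png")  # assumes the outputs/ directory already exists
```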
**Unpaired Image Translation (CycleGAN-Turbo)**
-- The following command takes an image file as input, and saves the results in the directory specified.
-```bash
+- The following command takes a **day** image file as input, and saves the output **night** image in the specified directory.
+```
 python src/inference_unpaired.py --model "day_to_night" \
-    --input_image "assets/day.png" --output_dir "outputs"
+    --input_image "assets/examples/day2night_input.png" --output_dir "outputs"
 ```
<table>
<tr>
<th>Input (day)</th>
<th>Model Output (night)</th>
</tr>
<tr>
<td><img src='assets/examples/day2night_input.png' width="400px"></td>
<td><img src='assets/examples/day2night_output.png' width="400px"></td>
</tr>
</table>
- The following command takes a **night** image file as input, and saves the output **day** image in the specified directory.
```
python src/inference_unpaired.py --model "night_to_day" \
--input_image "assets/examples/night2day_input.png" --output_dir "outputs"
```
<table>
<tr>
<th>Input (night)</th>
<th>Model Output (day)</th>
</tr>
<tr>
<td><img src='assets/examples/night2day_input.png' width="400px"></td>
<td><img src='assets/examples/night2day_output.png' width="400px"></td>
</tr>
</table>
## Gradio Demo
- We provide a Gradio demo for the paired image translation tasks.
File renamed without changes
Binary file added assets/examples/bird_canny.png
Binary file added assets/examples/bird_canny_blue.png
Binary file added assets/examples/circles_inference_input.png
Binary file added assets/examples/circles_inference_output.png
File renamed without changes
Binary file added assets/examples/day2night_output.png
Binary file added assets/examples/night2day_input.png
Binary file added assets/examples/night2day_output.png
File renamed without changes
Binary file added assets/examples/sketch_output.png
Binary file added assets/examples/training_evaluation.png
Binary file added assets/examples/training_step_0.png
Binary file added assets/examples/training_step_500.png
Binary file added assets/examples/training_step_6000.png
118 changes: 118 additions & 0 deletions docs/training_pix2pix_turbo.md
@@ -0,0 +1,118 @@
## Training with Paired Data (pix2pix-turbo)
Here, we show how to train a pix2pix-turbo model using paired data.
We will use the [Fill50k dataset](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md) used by [ControlNet](https://github.com/lllyasviel/ControlNet) as an example dataset.


### Step 1. Get the Dataset
- First, download the Fill50k dataset (originally released [here](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip)) using the provided script:
```
bash scripts/download_fill50k.sh
```
- Our training scripts expect the dataset to be in the following format:
```
data
├── dataset_name
│   ├── train_A
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   ├── train_B
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   ├── train_prompts.json
│   ├── test_A
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   ├── test_B
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   └── test_prompts.json
```
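The contents of the prompts files are not specified above; a reasonable assumption is that `train_prompts.json` (and `test_prompts.json`) map each file name in `train_A`/`train_B` to its caption. The sketch below loads one training pair under that assumption; the helper and the example caption are illustrative, not the repository's dataset class.

```python
# A minimal sketch of reading one training pair from the layout above.
# Assumption: train_prompts.json maps each file name to its caption, e.g.
#   {"000000.png": "light coral circle with white background", ...}
# The helper below is illustrative, not the repository's dataset class.
import json
import os
from PIL import Image

def load_pair(dataset_root, name, split="train"):
    with open(os.path.join(dataset_root, f"{split}_prompts.json")) as f:
        prompts = json.load(f)
    cond = Image.open(os.path.join(dataset_root, f"{split}_A", name)).convert("RGB")    # conditioning image
    target = Image.open(os.path.join(dataset_root, f"{split}_B", name)).convert("RGB")  # target image
    return cond, target, prompts[name]

cond, target, prompt = load_pair("data/my_fill50k", "000000.png")
print(prompt)
```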
### Step 2. Train the Model
- Initialize the `accelerate` environment with the following command:
```
accelerate config
```
- Run the following command to train the model.
```
accelerate launch src/train_pix2pix_turbo.py \
--pretrained_model_name_or_path="stabilityai/sd-turbo" \
--output_dir="output/pix2pix_turbo/fill50k" \
--dataset_folder="data/my_fill50k" \
--resolution=512 \
--train_batch_size=2 \
--enable_xformers_memory_efficient_attention --viz_freq 25 \
--track_val_fid \
--report_to "wandb" --tracker_project_name "pix2pix_turbo_fill50k"
```
- Additional optional flags:
- `--track_val_fid`: Track FID score on the validation set using the [Clean-FID](https://github.com/GaParmar/clean-fid) implementation (an offline usage sketch follows this list).
- `--enable_xformers_memory_efficient_attention`: Enable memory-efficient attention in the model.
- `--viz_freq`: Frequency of visualizing the results during training.
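For reference, the same FID reported by `--track_val_fid` can also be computed offline with the Clean-FID library on two image folders; the folder paths below are illustrative.

```python
# A sketch of computing FID offline with Clean-FID, the library behind --track_val_fid.
# The folder paths are illustrative; point them at real and generated test images.
from cleanfid import fid

score = fid.compute_fid("data/my_fill50k/test_B", "outputs/fill50k_eval")
print(f"FID: {score:.2f}")
```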
### Step 3. Monitor the training progress
- You can monitor the training progress using the [Weights & Biases](https://wandb.ai/site) dashboard.
- The training script will visualize the training batch, the training losses, and the validation-set L2, LPIPS, and FID scores (if specified).
<div>
<p align="center">
<img src='../assets/examples/training_evaluation.png' align="center" width=800px>
</p>
</div>
- The model checkpoints will be saved in the `<output_dir>/checkpoints` directory.
- Screenshots of the training progress are shown below:
- Step 0:
<div>
<p align="center">
<img src='../assets/examples/training_step_0.png' align="center" width=800px>
</p>
</div>
- Step 500:
<div>
<p align="center">
<img src='../assets/examples/training_step_500.png' align="center" width=800px>
</p>
</div>
- Step 6000:
<div>
<p align="center">
<img src='../assets/examples/training_step_6000.png' align="center" width=800px>
</p>
</div>
### Step 4. Run Inference with the Trained Model
- You can run inference with the trained model using the following command:
```
python src/inference_paired.py --model_path "output/pix2pix_turbo/fill50k/checkpoints/model_6001.pkl" \
--input_image "data/my_fill50k/test_A/40000.png" \
--prompt "violet circle with orange background" \
--output_dir "outputs"
```
- The above command should generate the following output:
<table>
<tr>
<th>Model Input</th>
<th>Model Output</th>
</tr>
<tr>
<td><img src='../assets/examples/circles_inference_input.png' width="200px"></td>
<td><img src='../assets/examples/circles_inference_output.png' width="200px"></td>
</tr>
</table>
5 changes: 5 additions & 0 deletions scripts/download_fill50k.sh
@@ -0,0 +1,5 @@
mkdir -p data
wget https://www.cs.cmu.edu/~img2img-turbo/data/my_fill50k.zip -O data/my_fill50k.zip
cd data
unzip my_fill50k.zip
rm my_fill50k.zip
73 changes: 67 additions & 6 deletions src/cyclegan_turbo.py
@@ -47,15 +47,15 @@ def forward(self, x, direction):


class CycleGAN_Turbo(torch.nn.Module):
-def __init__(self, name, ckpt_folder="checkpoints"):
+def __init__(self, pretrained_name, ckpt_folder="checkpoints"):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
self.sched = make_1step_sched()
vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")

if name == "day_to_night":
if pretrained_name == "day_to_night":
# download the checkopoint from the url
url = "https://www.cs.cmu.edu/~img2img-turbo/models/day2night.pkl"
os.makedirs(ckpt_folder, exist_ok=True)
@@ -113,19 +113,80 @@ def __init__(self, name, ckpt_folder="checkpoints"):
vae_dec.load_state_dict(sd["sd_vae_dec"])
self.timesteps = torch.tensor([999], device="cuda").long()
self.caption = "driving in the night"
self.direction = "a2b"

elif pretrained_name == "night_to_day":
# download the checkpoint from the url
url = "https://www.cs.cmu.edu/~img2img-turbo/models/night2day.pkl"
os.makedirs(ckpt_folder, exist_ok=True)
outf = os.path.join(ckpt_folder, "night2day.pkl")
if not os.path.exists(outf):
print(f"Downloading checkpoint to {outf}")
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(outf, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
print("ERROR, something went wrong")
print(f"Downloaded successfully to {outf}")

sd = torch.load(outf)
lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_encoder"], lora_alpha=sd["rank_unet"])
lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_decoder"], lora_alpha=sd["rank_unet"])
lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_modules_others"], lora_alpha=sd["rank_unet"])
unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
unet.add_adapter(lora_conf_others, adapter_name="default_others")
for n, p in unet.named_parameters():
name_sd = n.replace(".default_encoder.weight", ".weight")
if "lora" in n and "default_encoder" in n:
p.data.copy_(sd["sd_encoder"][name_sd])
for n, p in unet.named_parameters():
name_sd = n.replace(".default_decoder.weight", ".weight")
if "lora" in n and "default_decoder" in n:
p.data.copy_(sd["sd_decoder"][name_sd])
for n, p in unet.named_parameters():
name_sd = n.replace(".default_others.weight", ".weight")
if "lora" in n and "default_others" in n:
p.data.copy_(sd["sd_other"][name_sd])
unet.set_adapter(["default_encoder", "default_decoder", "default_others"])

vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
vae.decoder.ignore_skip = False
vae_lora_config = LoraConfig(r=4, init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
vae.decoder.gamma = 1
vae_b2a = copy.deepcopy(vae)
vae_enc = VAE_encode(vae, vae_b2a=vae_b2a)
vae_enc.load_state_dict(sd["sd_vae_enc"])
vae_dec = VAE_decode(vae, vae_b2a=vae_b2a)
vae_dec.load_state_dict(sd["sd_vae_dec"])
self.timesteps = torch.tensor([999], device="cuda").long()
self.caption = "driving in the day"
self.direction = "b2a"

vae_enc.cuda()
vae_dec.cuda()
unet.cuda()
unet.enable_xformers_memory_efficient_attention()
self.unet, self.vae_enc, self.vae_dec = unet, vae_enc, vae_dec

-def forward(self, x_t, direction):
-assert direction in ["a2b", "b2a"]
+def forward(self, x_t):
caption_tokens = self.tokenizer(self.caption, max_length=self.tokenizer.model_max_length,
padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
caption_enc = self.text_encoder(caption_tokens)[0]
-x_t_enc = self.vae_enc(x_t, direction=direction)
+x_t_enc = self.vae_enc(x_t, direction=self.direction)
model_pred = self.unet(x_t_enc, self.timesteps, encoder_hidden_states=caption_enc,).sample
x_denoised = self.sched.step(model_pred, self.timesteps, x_t_enc, return_dict=True).prev_sample
-output = self.vae_dec(x_denoised, direction=direction)
+output = self.vae_dec(x_denoised, direction=self.direction)
return output
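With this change, the translation direction and caption are fixed by `pretrained_name`, so callers no longer pass a `direction` argument to `forward()`. Below is a minimal usage sketch of the updated interface; the import path and the `[-1, 1]` input scaling are assumptions for illustration (the actual pipeline lives in `src/inference_unpaired.py`, not shown in this diff).

```python
# A minimal usage sketch of the updated CycleGAN_Turbo interface: the translation
# direction and caption are fixed by pretrained_name, so forward() only takes the
# input tensor. The import path and [-1, 1] input range are assumptions; see
# src/inference_unpaired.py for the actual inference pipeline.
import os
import torch
from PIL import Image
from torchvision import transforms
from cyclegan_turbo import CycleGAN_Turbo

model = CycleGAN_Turbo(pretrained_name="day_to_night")  # downloads day2night.pkl on first use
model.eval()

img = Image.open("assets/examples/day2night_input.png").convert("RGB")
x = transforms.ToTensor()(img).unsqueeze(0).cuda() * 2.0 - 1.0  # assumed [-1, 1] input range
with torch.no_grad():
    out = model(x)  # direction is taken from self.direction ("a2b" for day_to_night)

os.makedirs("outputs", exist_ok=True)
transforms.ToPILImage()(out[0].cpu() * 0.5 + 0.5).save("outputs/day2night_output.png")
```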
23 changes: 18 additions & 5 deletions src/inference_paired.py
@@ -1,5 +1,6 @@
import os
import argparse
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
@@ -11,41 +12,53 @@
parser = argparse.ArgumentParser()
parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
parser.add_argument('--prompt', type=str, required=True, help='the prompt to be used')
-parser.add_argument('--model_name', type=str, default='edge_to_image', help='name of the model to be used')
+parser.add_argument('--model_name', type=str, default='', help='name of the pretrained model to be used')
+parser.add_argument('--model_path', type=str, default='', help='path to a model state dict to be used')
parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
parser.add_argument('--low_threshold', type=int, default=100, help='Canny low threshold')
parser.add_argument('--high_threshold', type=int, default=200, help='Canny high threshold')
parser.add_argument('--gamma', type=float, default=0.4, help='The sketch interpolation guidance amount')
parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
args = parser.parse_args()

# exactly one of model_name and model_path should be provided
if (args.model_name == '') == (args.model_path == ''):
    raise ValueError('Exactly one of model_name or model_path should be provided')

# initialize the model
-model = Pix2Pix_Turbo(args.model_name)
+model = Pix2Pix_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
model.set_eval()

# make sure that the input image is a multiple of 8
input_image = Image.open(args.input_image).convert('RGB')
new_width = input_image.width - input_image.width % 8
new_height = input_image.height - input_image.height % 8
input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
bname = os.path.basename(args.input_image)

# translate the image
with torch.no_grad():
if args.model_name == 'edge_to_image':
canny = canny_from_pil(input_image, args.low_threshold, args.high_threshold)
c_t = transforms.ToTensor()(canny).unsqueeze(0).cuda()
canny_viz_inv = Image.fromarray(255 - np.array(canny))
canny_viz_inv.save(os.path.join(args.output_dir, bname.replace('.png', '_canny.png')))
c_t = F.to_tensor(canny).unsqueeze(0).cuda()
output_image = model(c_t, args.prompt)

-if args.model_name == 'sketch_to_image_stochastic':
+elif args.model_name == 'sketch_to_image_stochastic':
image_t = F.to_tensor(input_image) < 0.5
c_t = image_t.unsqueeze(0).cuda().float()
torch.manual_seed(args.seed)
B, C, H, W = c_t.shape
noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
output_image = model(c_t, args.prompt, deterministic=False, r=args.gamma, noise_map=noise)

else:
c_t = F.to_tensor(input_image).unsqueeze(0).cuda()
output_image = model(c_t, args.prompt)

output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)

# save the output image
bname = os.path.basename(args.input_image)
os.makedirs(args.output_dir, exist_ok=True)
output_pil.save(os.path.join(args.output_dir, bname))