-
Notifications
You must be signed in to change notification settings - Fork 7
/
main.py
49 lines (37 loc) · 1.7 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import argparse
import os

from torchvision import datasets

from augment.diffuseMix import DiffuseMix
from augment.handler import ModelHandler
from augment.utils import Utils
def parse_arguments(argv=None):
    """Parse the command-line options for the augmentation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with three string attributes: ``train_dir``,
        ``fractal_dir`` and ``prompts`` (the raw comma-separated string).

    Raises:
        SystemExit: raised by argparse when a required option is missing
            or an unknown option is supplied.
    """
    parser = argparse.ArgumentParser(
        description="Generate an augmented dataset from original images and fractal patterns."
    )
    parser.add_argument(
        '--train_dir',
        type=str,
        required=True,
        help='Path to the directory containing the original training images.',
    )
    parser.add_argument(
        '--fractal_dir',
        type=str,
        required=True,
        help='Path to the directory containing the fractal images.',
    )
    parser.add_argument(
        '--prompts',
        type=str,
        required=True,
        help='Comma-separated list of prompts for image generation.',
    )
    return parser.parse_args(argv)
def main():
    """Generate the augmented dataset and save every image it yields.

    Reads the CLI options, wraps the ``ImageFolder`` at ``--train_dir`` and
    the fractal images from ``--fractal_dir`` in a ``DiffuseMix`` dataset,
    then writes each yielded image to ``augmented_images/<index>.png``
    relative to the current working directory.

    Raises:
        SystemExit: if ``--prompts`` contains no non-empty prompt, or if
            argument parsing fails.
    """
    args = parse_arguments()

    # Trim whitespace around each prompt and drop empty entries, so that
    # "a, b" or a trailing comma does not produce ' b' or '' as a prompt.
    prompts = [prompt.strip() for prompt in args.prompts.split(',') if prompt.strip()]
    if not prompts:
        raise SystemExit('--prompts must contain at least one non-empty prompt.')

    # Initialize the model. The device is fixed to 'cuda' here.
    model_id = "timbrooks/instruct-pix2pix"
    model_initialization = ModelHandler(model_id=model_id, device='cuda')

    # Load the original dataset and build the inverse index -> class-name map.
    train_dataset = datasets.ImageFolder(root=args.train_dir)
    idx_to_class = {v: k for k, v in train_dataset.class_to_idx.items()}

    # Load fractal images
    fractal_imgs = Utils.load_fractal_images(args.fractal_dir)

    # Create the augmented dataset.
    # NOTE(review): the meaning of num_images / guidance_scale is defined by
    # DiffuseMix, which is not visible from this file — confirm there.
    augmented_train_dataset = DiffuseMix(
        original_dataset=train_dataset,
        fractal_imgs=fractal_imgs,
        num_images=1,
        guidance_scale=4,
        idx_to_class=idx_to_class,
        prompts=prompts,
        model_handler=model_initialization,
    )

    # The save below fails if the target directory is missing, so create it
    # up front; exist_ok keeps re-runs from raising.
    output_dir = 'augmented_images'
    os.makedirs(output_dir, exist_ok=True)

    # The label is yielded alongside each image but is not needed for saving.
    for idx, (image, _label) in enumerate(augmented_train_dataset):
        image.save(f'{output_dir}/{idx}.png')
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()