-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDRM.py
144 lines (122 loc) · 4.17 KB
/
DRM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch.nn as nn
import torch
import torch.nn.functional as F
from guided_diffusion.script_util import (
NUM_CLASSES,
model_and_diffusion_defaults,
create_model_and_diffusion,
args_to_dict,
)
from pdb import set_trace as st
class Args:
    """Static hyper-parameter namespace standing in for parsed CLI arguments.

    An instance is passed to ``args_to_dict`` together with
    ``model_and_diffusion_defaults().keys()``, so only attributes whose names
    appear in those defaults are forwarded to ``create_model_and_diffusion``.
    The remaining attributes look like leftovers from a sampling script
    (NOTE(review): confirm nothing else reads them).
    """

    # --- network architecture -------------------------------------------
    image_size: int = 256
    num_channels: int = 256
    num_res_blocks: int = 2
    num_heads: int = 4
    num_heads_upsample: int = -1
    num_head_channels: int = 64
    attention_resolutions: str = "32,16,8"
    channel_mult: str = ""
    dropout: float = 0.0
    class_cond: bool = False
    use_checkpoint: bool = False
    use_scale_shift_norm: bool = True
    resblock_updown: bool = True
    use_fp16: bool = False
    use_new_attention_order: bool = False

    # --- sampling-script options ------------------------------------------
    clip_denoised: bool = True
    num_samples: int = 10000
    batch_size: int = 16
    use_ddim: bool = False
    model_path: str = ""
    classifier_path: str = ""
    classifier_scale: float = 1.0

    # --- diffusion process ------------------------------------------------
    learn_sigma: bool = True
    diffusion_steps: int = 1000
    noise_schedule: str = "linear"
    timestep_respacing: "str | None" = None
    use_kl: bool = False
    predict_xstart: bool = False
    rescale_timesteps: bool = False
    rescale_learned_sigmas: bool = False
class DiffusionRobustModel(nn.Module):
    """Smoothed classifier that optionally denoises noisy inputs with a diffusion model.

    Every input is replicated ``num_noise_vec`` times.  Each copy is either
    perturbed with Gaussian noise of std ``sigma`` (``no_diffusion=True``) or
    noised to the matching diffusion timestep and denoised in one step by a
    pretrained guided-diffusion model.  The wrapped ``classifier`` is applied
    to every copy and the per-copy softmax scores are averaged per input.

    The timestep matching ``noise_sd`` follows https://arxiv.org/abs/2206.10550.
    """

    def __init__(self, classifier, noise_sd, num_noise_vec=40, no_diffusion=False, dataset="trojai"):
        """Build and load the diffusion model and wrap both networks for the GPU.

        Args:
            classifier: network applied to the (de)noised copies; presumably
                expects images in [0, 1] -- TODO confirm with callers.
            noise_sd: noise std on the input's scale.  It is doubled when
                matched against the diffusion schedule because inputs are
                mapped to [-1, 1] (``x * 2 - 1``) before denoising.
            num_noise_vec: number of noisy copies drawn per input.
            no_diffusion: if True, only add Gaussian noise (no denoising).
            dataset: stored as ``self.dataset``; not read by this class.

        Raises:
            ValueError: if ``noise_sd`` exceeds the largest noise level the
                diffusion schedule can represent.
        """
        super().__init__()
        model, diffusion = create_model_and_diffusion(
            **args_to_dict(Args(), model_and_diffusion_defaults().keys())
        )
        # Resolve the timestep first so an unusable noise_sd fails before the
        # (large) checkpoint is loaded.
        t = self._timestep_for_sigma(diffusion, noise_sd)
        # Relative path: resolved against the current working directory.
        # Load onto the CPU; the model is moved to the GPU right after.
        model.load_state_dict(
            torch.load("weights/256x256_diffusion_uncond.pt", map_location="cpu")
        )
        model.eval().cuda()
        model = torch.nn.DataParallel(model).cuda()
        classifier = torch.nn.DataParallel(classifier).cuda()
        self.model = model
        self.diffusion = diffusion
        self.classifier = classifier
        self.dataset = dataset
        self.no_diffusion = no_diffusion
        self.num_noise_vec = num_noise_vec
        self.sigma = noise_sd
        self.t = t
        print("t found for sigma %.2f: %d"%(noise_sd, t))

    @staticmethod
    def _timestep_for_sigma(diffusion, noise_sd):
        """Return the smallest t >= 1 whose diffusion noise level reaches 2 * noise_sd.

        The noise level at step t is
        ``sqrt_one_minus_alphas_cumprod[t] / sqrt_alphas_cumprod[t]``.
        Returns 0 when ``noise_sd <= 0`` (the search loop is never entered).

        Raises:
            ValueError: if no timestep in the schedule is noisy enough.
        """
        sqrt_alpha = diffusion.sqrt_alphas_cumprod
        sqrt_one_minus_alpha = diffusion.sqrt_one_minus_alphas_cumprod
        # Factor 2: inputs are rescaled from [0, 1] to [-1, 1] before diffusion.
        target = noise_sd * 2
        real_sigma = 0
        t = 0
        while real_sigma < target:
            t += 1
            if t >= len(sqrt_alpha):
                raise ValueError(
                    f"noise_sd={noise_sd} exceeds the largest noise level of the "
                    f"diffusion schedule ({len(sqrt_alpha)} timesteps)"
                )
            real_sigma = sqrt_one_minus_alpha[t] / sqrt_alpha[t]
        return t

    def reset_sigma(self, noise_sd):
        """Switch to a new noise std and recompute the matching timestep.

        Raises:
            ValueError: if ``noise_sd`` is too large for the diffusion schedule;
                ``sigma`` and ``t`` are left unchanged in that case.
        """
        t = self._timestep_for_sigma(self.diffusion, noise_sd)
        self.sigma = noise_sd
        self.t = t
        print("reset t for sigma %.2f: %d"%(noise_sd, t))

    def forward(self, x):
        """Classify ``num_noise_vec`` noisy copies of each input and average them.

        ``x`` is presumably an image batch (B, C, H, W) in [0, 1] -- TODO
        confirm.  Copies are laid out as ``x.repeat`` produces them: row
        ``n * B + b`` is copy ``n`` of input ``b``.

        Returns:
            Softmax probabilities averaged over the copies of each input,
            shape (B, num_classes).  Exception: when ``num_noise_vec == 1``
            and there is more than one row, the raw classifier output is
            returned without softmax.
        """
        x = x.repeat((self.num_noise_vec, 1, 1, 1))
        # Derive the input count after repeat so a 3-D single image (which
        # repeat promotes to 4-D) is treated as one input.
        num_inputs = x.shape[0] // self.num_noise_vec
        if self.no_diffusion:  # w/o diffusion: plain Gaussian smoothing
            x += torch.randn_like(x) * self.sigma
        else:  # w/ diffusion: map to [-1, 1], noise + denoise, map back
            x = x * 2 - 1
            x = self.diffusion_denoise(x, self.t)
            x = (x + 1) / 2
        out = self.classifier(x)
        if self.num_noise_vec == 1 and x.shape[0] != 1:
            return out
        out = F.softmax(out, dim=1)
        # Average only over the noisy copies of the same input; averaging over
        # all rows would mix scores from different images in the batch.
        out = out.reshape(self.num_noise_vec, num_inputs, *out.shape[1:])
        return out.mean(dim=0)

    def diffusion_denoise(self, x_start, t, multistep=False):
        """Noise ``x_start`` to timestep ``t`` with ``q_sample`` and denoise it.

        Args:
            x_start: inputs already mapped to [-1, 1] (``forward`` does this).
            t: diffusion timestep to noise to.
            multistep: if False, return the model's one-step ``pred_xstart``
                estimate at timestep ``t``; if True, iterate ``p_sample`` over
                timesteps t-1, ..., 0 and return the final ``sample``.

        Returns:
            Tensor with the same shape as ``x_start``; values are clipped to
            [-1, 1] by ``clip_denoised=True``.
        """
        batch = len(x_start)
        t_batch = torch.full((batch,), t, dtype=torch.long, device=x_start.device)
        noise = torch.randn_like(x_start)
        # Gaussian noise is added at this step.
        x_t_start = self.diffusion.q_sample(x_start=x_start, t=t_batch, noise=noise)
        # enable_grad keeps the denoiser differentiable even when the caller
        # runs under no_grad (presumably for gradient-based attacks).
        with torch.enable_grad():
            if multistep:
                out = x_t_start
                # NOTE(review): x_t_start was noised at timestep t, but the
                # first reverse step below uses t-1; starting at t may be the
                # intended schedule -- confirm before relying on multistep.
                for i in reversed(range(t)):
                    step = torch.full((batch,), i, dtype=torch.long, device=x_start.device)
                    out = self.diffusion.p_sample(
                        self.model,
                        out,
                        step,
                        clip_denoised=True
                    )['sample']
            else:
                out = self.diffusion.p_sample(
                    self.model,
                    x_t_start,
                    t_batch,
                    clip_denoised=True
                )['pred_xstart']
        return out