-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisual_prompt.py
108 lines (90 loc) · 4.16 KB
/
visual_prompt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class ExpansiveVisualPrompt(nn.Module):
    """Visual prompt that enlarges the canvas and learns only the new border.

    The input is zero-padded from ``args.input_size`` up to
    ``args.output_size`` on the last two dimensions. A learnable ``program``
    tensor is passed through a sigmoid and added only where ``mask`` is 1,
    i.e. on the padded border; the region covering the original image has
    mask 0 and is left as-is. The result is clamped to [0, 1] and optionally
    normalized.

    Args:
        args: object exposing integer ``output_size`` and ``input_size``.
        normalize: optional callable applied to the clamped output.
    """

    def __init__(self, args, normalize=None):
        print('prompt method: expand\n')
        super().__init__()
        out_size, in_size = args.output_size, args.input_size
        extra = out_size - in_size
        # For an odd size difference the extra pixel goes on the left/top;
        # int(.../2) truncates toward zero.
        self.l_pad = int((extra + 1) / 2)
        self.r_pad = int(extra / 2)
        pads = (self.l_pad, self.r_pad, self.l_pad, self.r_pad)
        # 0 over the original image area, 1 over the added border.
        border = F.pad(torch.zeros(3, in_size, in_size), pads, value=1)
        self.register_buffer("mask", border)
        self.program = torch.nn.Parameter(data=torch.zeros(3, out_size, out_size))
        self.normalize = normalize

    def forward(self, x):
        """Pad ``x``, add the masked sigmoid prompt, clamp, then normalize.

        NOTE(review): ``x`` presumably has shape
        (batch, 3, input_size, input_size) so that the padded result matches
        the (3, output_size, output_size) program — confirm against callers.
        """
        pads = (self.l_pad, self.r_pad, self.l_pad, self.r_pad)
        canvas = F.pad(x, pads, value=0)
        prompt = torch.sigmoid(self.program) * self.mask
        out = (canvas + prompt).clamp(0, 1)
        return out if self.normalize is None else self.normalize(out)
class PadVisualPrompt(nn.Module):
    """Visual prompt learned on a frame of width ``args.pad_size``.

    The input is zero-padded from ``args.input_size`` to ``args.output_size``.
    ``mask`` is 1 on a border frame ``pad_size`` pixels wide (or everywhere
    when ``output_size == 2 * pad_size``) and 0 inside it; ``program`` is
    added only where the mask is 1. The result is clamped to [0, 1] and
    optionally normalized.

    Args:
        args: object exposing integer ``pad_size``, ``output_size`` and
            ``input_size``.
        normalize: optional callable applied to the clamped output.

    Raises:
        ValueError: if ``2 * pad_size`` exceeds ``output_size``.
    """

    def __init__(self, args, normalize=None):
        print('prompt method: pad\n')
        super().__init__()
        pad = args.pad_size
        out_size, in_size = args.output_size, args.input_size
        if out_size < 2 * pad:
            raise ValueError("Pad Should Not Exceed Half Of Output Size")
        extra = out_size - in_size
        # For an odd size difference the extra pixel goes on the left/top.
        self.l_pad = int((extra + 1) / 2)
        self.r_pad = int(extra / 2)
        self.normalize = normalize
        self.program = torch.nn.Parameter(data=torch.zeros(3, out_size, out_size))
        if out_size == 2 * pad:
            # The frame covers the whole canvas.
            frame = torch.ones(3, out_size, out_size)
        else:
            inner = out_size - 2 * pad
            frame = F.pad(torch.zeros(3, inner, inner), (pad,) * 4, value=1)
        self.register_buffer("mask", frame)

    def forward(self, x):
        """Pad ``x``, add the masked prompt, clamp, then normalize."""
        pads = (self.l_pad, self.r_pad, self.l_pad, self.r_pad)
        # NOTE(review): unlike the sibling prompt classes, the raw program is
        # added here without a sigmoid — confirm this is intentional.
        prompted = F.pad(x, pads, value=0) + self.program * self.mask
        prompted = prompted.clamp(0, 1)
        if self.normalize is not None:
            prompted = self.normalize(prompted)
        return prompted
class FixVisualPrompt(nn.Module):
    """Visual prompt learned on a fixed top-left square patch.

    The input is zero-padded from ``args.input_size`` to ``args.output_size``.
    ``mask`` is 1 on the top-left ``mask_size`` x ``mask_size`` patch of the
    canvas and 0 elsewhere; the sigmoid of ``program`` is added only on that
    patch. The result is clamped to [0, 1] and optionally normalized.

    Args:
        args: object exposing integer ``mask_size``, ``output_size`` and
            ``input_size``.
        normalize: optional callable applied to the clamped output. Defaults
            to ``None`` for consistency with the other prompt classes.
    """

    def __init__(self, args, normalize=None):
        print('prompt method: fix\n')
        super(FixVisualPrompt, self).__init__()
        mask_size = args.mask_size
        output_size = args.output_size
        input_size = args.input_size
        # For an odd size difference the extra pixel goes on the left/top.
        self.l_pad = int((output_size - input_size + 1) / 2)
        self.r_pad = int((output_size - input_size) / 2)
        mask = torch.zeros(3, output_size, output_size)
        # Only the top-left patch of the canvas is learnable.
        mask[:, :mask_size, :mask_size] = 1
        self.register_buffer("mask", mask)
        self.program = torch.nn.Parameter(data=torch.zeros(3, output_size, output_size))
        self.normalize = normalize

    def forward(self, x):
        """Pad ``x``, add the masked sigmoid prompt, clamp, then normalize.

        NOTE(review): ``x`` presumably has shape
        (batch, 3, input_size, input_size) — confirm against callers.
        """
        x = F.pad(x, (self.l_pad, self.r_pad, self.l_pad, self.r_pad), value=0) + torch.sigmoid(self.program) * self.mask
        x = x.clamp(0, 1)
        if self.normalize is not None:
            x = self.normalize(x)
        return x
class RandomVisualPrompt(nn.Module):
    """Visual prompt whose learnable square patch moves randomly per call.

    The input is zero-padded from ``args.input_size`` to ``args.output_size``.
    On every forward pass a ``mask_size`` x ``mask_size`` patch is placed at
    a uniformly random position (drawn with ``np.random``) covering every
    valid corner, and the sigmoid of ``program`` is added only on that patch.
    The result is clamped to [0, 1] and optionally normalized.

    Args:
        args: object exposing integer ``output_size``, ``input_size`` and
            ``mask_size``.
        normalize: optional callable applied to the clamped output. Defaults
            to ``None`` for consistency with the other prompt classes.

    Raises:
        ValueError: if ``mask_size`` is negative or exceeds ``output_size``.
    """

    def __init__(self, args, normalize=None):
        print('prompt method: random\n')
        super(RandomVisualPrompt, self).__init__()
        output_size = args.output_size
        input_size = args.input_size
        if not 0 <= args.mask_size <= output_size:
            raise ValueError(
                f"mask_size must be in [0, output_size={output_size}], got {args.mask_size}"
            )
        self.mask_size = args.mask_size
        self.output_size = output_size
        self.input_size = input_size
        # For an odd size difference the extra pixel goes on the left/top.
        self.l_pad = int((output_size - input_size + 1) / 2)
        self.r_pad = int((output_size - input_size) / 2)
        self.program = torch.nn.Parameter(data=torch.zeros(3, output_size, output_size))
        self.normalize = normalize

    def forward(self, x):
        """Pad ``x``, add the sigmoid prompt on a random patch, clamp, normalize."""
        size = self.mask_size
        # Valid top-left corners are 0..output_size - mask_size inclusive, so
        # the exclusive upper bound for randint is output_size - mask_size + 1.
        n_positions = self.output_size - size + 1
        row = np.random.randint(0, n_positions)
        col = np.random.randint(0, n_positions)
        # Build the mask on the program's device rather than hard-coding CUDA,
        # so the module works on CPU and on any GPU it is moved to.
        mask = torch.zeros(3, self.output_size, self.output_size, device=self.program.device)
        mask[:, row:row + size, col:col + size] = 1
        x = F.pad(x, (self.l_pad, self.r_pad, self.l_pad, self.r_pad), value=0) + torch.sigmoid(self.program) * mask
        x = x.clamp(0, 1)
        if self.normalize is not None:
            x = self.normalize(x)
        return x