Notebook for attention map over input image
Lance Legel committed Nov 12, 2023
1 parent da4b382 commit df7265c
Showing 4 changed files with 556 additions and 11 deletions.
14 changes: 9 additions & 5 deletions dinov2/layers/attention.py
@@ -53,7 +53,7 @@ def __init__(
         self.proj = nn.Linear(dim, dim, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: Tensor, return_attn=False) -> Tensor:
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

@@ -66,15 +66,19 @@ def forward(self, x: Tensor) -> Tensor:
         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
+
+        if return_attn:
+            return attn
+
         return x


+
 class MemEffAttention(Attention):
-    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+    def forward(self, x: Tensor, attn_bias=None, return_attn=False) -> Tensor:
         if not XFORMERS_AVAILABLE:
-            if attn_bias is not None:
-                raise AssertionError("xFormers is required for using nested tensors")
-            return super().forward(x)
+            assert attn_bias is None, "xFormers is required for nested tensors usage"
+            return super().forward(x, return_attn)

         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
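For reference, a minimal sketch of what the modified Attention.forward exposes when return_attn=True: the post-softmax attention weights rather than the projected tokens. The layer sizes below are illustrative (ViT-S-like), not taken from the commit.

import torch

from dinov2.layers.attention import Attention

# Illustrative sizes: batch of 1, 257 tokens (1 CLS + 16x16 patches), embed dim 384, 6 heads.
attn_layer = Attention(dim=384, num_heads=6)
tokens = torch.randn(1, 257, 384)

out = attn_layer(tokens)                     # usual path: projected tokens, shape (1, 257, 384)
attn = attn_layer(tokens, return_attn=True)  # new path: attention weights, shape (1, 6, 257, 257)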
16 changes: 10 additions & 6 deletions dinov2/layers/block.py
@@ -86,13 +86,16 @@ def __init__(

         self.sample_drop_ratio = drop_path

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: Tensor, return_attention=False) -> Tensor:
         def attn_residual_func(x: Tensor) -> Tensor:
             return self.ls1(self.attn(self.norm1(x)))

         def ffn_residual_func(x: Tensor) -> Tensor:
             return self.ls2(self.mlp(self.norm2(x)))

+        if return_attention:
+            return self.attn(self.norm1(x), return_attn=True)
+
         if self.training and self.sample_drop_ratio > 0.1:
             # the overhead is compensated only for a drop path rate larger than 0.1
             x = drop_add_residual_stochastic_depth(
@@ -114,6 +117,7 @@ def ffn_residual_func(x: Tensor) -> Tensor:
         return x


+
 def drop_add_residual_stochastic_depth(
     x: Tensor,
     residual_func: Callable[[Tensor], Tensor],
@@ -249,12 +253,12 @@ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
             x = x + ffn_residual_func(x)
             return attn_bias.split(x)

-    def forward(self, x_or_x_list):
+    def forward(self, x_or_x_list, return_attention=False):
         if isinstance(x_or_x_list, Tensor):
-            return super().forward(x_or_x_list)
+            return super().forward(x_or_x_list, return_attention)
         elif isinstance(x_or_x_list, list):
-            if not XFORMERS_AVAILABLE:
-                raise AssertionError("xFormers is required for using nested tensors")
+            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
             return self.forward_nested(x_or_x_list)
         else:
             raise AssertionError
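The block-level change simply forwards the flag: when return_attention=True, Block.forward short-circuits and returns the attention from its normalized self-attention sublayer. A minimal sketch, assuming a Block built with the plain Attention class (with MemEffAttention, the diff above only routes return_attn through the non-xFormers fallback path); sizes are illustrative.

import torch

from dinov2.layers.attention import Attention
from dinov2.layers.block import Block

# Illustrative ViT-S-like block: embed dim 384, 6 heads, plain Attention.
block = Block(dim=384, num_heads=6, attn_class=Attention)
block.eval()
tokens = torch.randn(1, 257, 384)

x = block(tokens)                            # residual update: shape (1, 257, 384)
attn = block(tokens, return_attention=True)  # per-head attention: shape (1, 6, 257, 257)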

13 changes: 13 additions & 0 deletions dinov2/models/vision_transformer.py
@@ -317,6 +317,19 @@ def get_intermediate_layers(
             return tuple(zip(outputs, class_tokens))
         return tuple(outputs)

+    def get_last_self_attention(self, x, masks=None):
+        if isinstance(x, list):
+            return self.forward_features_list(x, masks)
+
+        x = self.prepare_tokens_with_masks(x, masks)
+
+        # Run through model, at the last block just return the attention.
+        for i, blk in enumerate(self.blocks):
+            if i < len(self.blocks) - 1:
+                x = blk(x)
+            else:
+                return blk(x, return_attention=True)
+
     def forward(self, *args, is_training=False, **kwargs):
         ret = self.forward_features(*args, **kwargs)
         if is_training:
524 changes: 524 additions & 0 deletions notebooks/attention.ipynb

Large diffs are not rendered by default.
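The new notebook's contents are not rendered above, so here is a rough sketch of how the get_last_self_attention hook can be used to build an attention heatmap over an input image. This is not the notebook's own code: the constructor arguments, image size, and variable names are illustrative, and it assumes xFormers is not installed (so MemEffAttention falls back to the plain Attention path that returns the attention weights) and block_chunks=0 (so self.blocks iterates over individual blocks).

import torch
import torch.nn.functional as F

from dinov2.models.vision_transformer import vit_small

# Illustrative backbone: ViT-S/14 with unchunked blocks; load pretrained
# DINOv2 weights into `model` before visualizing real images.
model = vit_small(patch_size=14, block_chunks=0)
model.eval()

img = torch.randn(1, 3, 518, 518)  # stand-in for a normalized input image

with torch.no_grad():
    attn = model.get_last_self_attention(img)  # (1, num_heads, N, N), N = 1 + num_patches

num_heads = attn.shape[1]
h_feat = w_feat = 518 // 14                    # 37 x 37 patch grid for a 518 x 518 input
cls_attn = attn[0, :, 0, 1:]                   # CLS-token attention over the patch tokens
cls_attn = cls_attn.reshape(num_heads, h_feat, w_feat)

# Upsample each head's map back to image resolution for overlay on the input.
heatmaps = F.interpolate(cls_attn.unsqueeze(0), scale_factor=14, mode="nearest")[0]
print(heatmaps.shape)  # e.g. torch.Size([6, 518, 518]) for ViT-S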
