
Can overfitting lead to high-norm patches? #419

Closed · amundra15 opened this issue May 22, 2024 · 10 comments

@amundra15

I want to fine-tune vitb14 on domain-specific data, and as a proof of concept, I am starting with a fairly small dataset. The resulting patch features show high-norm artefacts similar to the ones discussed in "Vision Transformers Need Registers".

What confuses me is that the paper highlights that such artefacts were not noticed for vitb14, only for larger, more representative models. This makes me wonder whether the artefacts I am seeing for vitb14 are a sign of overfitting.

Any thoughts on this?

@heyoeyo commented May 22, 2024

There do already seem to be high-norm artifacts in vit-b (more info in issue #373), though they present a bit differently than in the larger models. Specifically for vit-b, there is always (weirdly) a bunch of high-norm tokens in the top-left patches.

I'm not sure about fine-tuning on small datasets, but the vit-b model was also used within Depth-Anything, which would've involved training on a large dataset for a different task, and it still shows similar artifacts. I'd guess that the artifacts you're seeing may just be the original ones, especially if they're concentrated in the top-left patches, and not directly related to fine-tuning.

@amundra15 (Author)

Thanks for your response, @heyoeyo.

In my case, the artefacts appear along the left and top edges of the image (and not just the top-left corner). What is also interesting is that I am getting low norm values for the artefacts, but high values for the first principal component.

Input RGB: [image: input(1)]

Norm of last-layer patch tokens: [image: our_norm_7000iter]

PCA(n=1) of last-layer patch tokens: [image: our_fgbg]

The values are overlaid in red (ignore the mismatch in image orientation).

I am not sure how to explain the low norm values for the artefacts.
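
(For context, a minimal sketch of how such norm and PCA(n=1) maps can be computed from the torch.hub `dinov2_vitb14` model; this is illustrative only, not necessarily the exact pipeline used for the images above.)

```python
# Illustrative sketch only: per-patch L2 norms and a 1-component PCA of the
# last-layer patch tokens from the torch.hub dinov2_vitb14 model.
import torch
from sklearn.decomposition import PCA

model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14").eval()

# Placeholder input: in practice an ImageNet-normalized RGB image whose height/width
# are multiples of the 14-pixel patch size (e.g. 518x518 -> 37x37 patches).
img = torch.randn(1, 3, 518, 518)
grid = 518 // 14  # 37

with torch.inference_mode():
    feats = model.forward_features(img)            # dict; key names follow the DINOv2 repo
    patch_tokens = feats["x_norm_patchtokens"][0]  # (37*37, 768), after the final LayerNorm

norm_map = patch_tokens.norm(dim=-1).reshape(grid, grid)  # per-patch L2 norm

pca = PCA(n_components=1)
pc1_map = pca.fit_transform(patch_tokens.cpu().numpy()).reshape(grid, grid)  # first principal component
```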

@heyoeyo commented May 23, 2024

That seems very surprising! It's probably worth double-checking whether the original (not fine-tuned) vit-b model produces similar artifacts (if you haven't already).

Another thing worth checking is whether the artifacts appear in earlier layers, and what that pattern looks like. In all cases I've seen, they aren't present in the earlier layers, but tend to appear and stay consistent in later layers, with the final layer being somewhat different from all the others.
Not that this will explain your results specifically, but if you see a similar pattern it may be a hint that it's the same phenomenon at least, even though you're getting low-norm tokens.

@amundra15 (Author)

The original model also produces similar low-norm artefacts (though not as prominently).

Norm of last-layer patch tokens from the official vitb14: [image: dino_norm]

PCA(n=1) of last-layer patch tokens from the official vitb14: [image: dino_fgbg(1)]

It's interesting to note that the original model shows regions of high as well as low norms. The fine-tuning is exacerbating the low-norm problem already present in the top-left region. Is this phenomenon studied and documented somewhere?

I will also add the visualizations from the other layers once I have them.

@heyoeyo commented May 28, 2024

Those results from the original model look a bit more similar to what I've seen, though it's strange that they seem inverted and that there are other tokens (not just the top-left) that seem out of place. They actually resemble the results from the larger models...
Are these norm results showing the patch tokens as-is, or do they also include the final layernorm step? If the layernorm is included, I wonder if that might at least explain the inversion of low/high norms?

As for visualizing the other layers, there's some code here which can at least give a qualitative result.
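
For reference, a rough sketch of both checks is below (it assumes the torch.hub `dinov2_vitb14` model, where the `norm=` flag of `get_intermediate_layers` controls whether the final LayerNorm is applied; it's illustrative, not the linked demo code):

```python
# Rough sketch: compare patch-token norms with/without the final LayerNorm,
# and look at a few earlier blocks (assumes the torch.hub dinov2_vitb14 model).
import torch

model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14").eval()
img = torch.randn(1, 3, 518, 518)  # placeholder; ImageNet-normalized image in practice

with torch.inference_mode():
    # Last block, with and without the final LayerNorm applied
    for apply_norm in (True, False):
        (tokens,) = model.get_intermediate_layers(img, n=1, norm=apply_norm)
        norms = tokens[0].norm(dim=-1)
        print(f"final LayerNorm={apply_norm}: norm range {norms.min():.1f} .. {norms.max():.1f}")

    # A few earlier blocks (0-indexed); in my experience the artifacts only
    # show up in the later blocks.
    block_ids = [2, 5, 8, 11]
    for idx, tokens in zip(block_ids, model.get_intermediate_layers(img, n=block_ids, norm=False)):
        print(f"block {idx}: max patch-token norm = {tokens[0].norm(dim=-1).max():.1f}")
```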

@amundra15 (Author)

@heyoeyo you are right regarding the final layer norm. Upon commenting it out, I get high norms as expected:

Norm of last-layer patch tokens from the official vitb14: [image: ori_vitb14_patchnorm]

Norm of last-layer patch tokens from our fine-tuned vitb14: [image: ours_vitb14_featurenorm]

The values are now similar to the ones discussed in the registers paper. However, I still observe artefacts along the entirety of the top and left edges.

I have a couple of queries regarding the impact on performance:

  1. Does this artefact somehow affect the cls token performance as well?
  2. What happens if I mask/threshold the patch token outputs for the artefacts? Can this ensure better downstream performance compared to using it as is?

@heyoeyo commented Jun 4, 2024

Does this artefact somehow affect the cls token performance as well?

I'd guess this depends a lot on how you're using the cls token. If you've trained another model to use the cls token (and included the vitb model in this training), then I'd imagine it's okay. The cls token has the chance to 'attend' to these weird high-norm tokens throughout the model, so even if they include global info (as the registers paper suggests), end-to-end training involving the cls token should be able to account for this (to some extent), I think.
On the other hand, if you're attaching a separate model/other classifier onto the vitb cls token without further training of the vitb model, it may perform more poorly since there is nothing guiding the model to place the most relevant info into the cls token specifically.
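
As a sketch of that second case (nothing specific to your setup; `num_classes`, the optimizer settings, and the data are placeholders), a frozen-backbone linear probe on the cls token would look roughly like:

```python
# Rough sketch: linear probe on the frozen cls token (the "separate classifier,
# no further ViT training" case). num_classes and the data are placeholders.
import torch
import torch.nn as nn

backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14").eval()
for p in backbone.parameters():
    p.requires_grad_(False)            # ViT stays frozen

num_classes = 10                       # placeholder
head = nn.Linear(backbone.embed_dim, num_classes)
optimizer = torch.optim.AdamW(head.parameters(), lr=1e-3)

def probe_step(img, labels):
    with torch.no_grad():              # no gradients through the frozen ViT
        cls_token = backbone.forward_features(img)["x_norm_clstoken"]
    loss = nn.functional.cross_entropy(head(cls_token), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```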

What happens if I mask/threshold the patch token outputs for the artefacts? Can this ensure better downstream performance compared to using it as is?

I think this again depends on how the model is used. If the downstream processing is trained in conjunction with the vit output, then it's likely to outperform any hand-picked mask/thresholding settings (given how hard it would be to do this with these odd patterns).
The Depth-Anything models have a substantial amount of post-processing after the vit encoding and seem to perform fine, at least. Of course, there's always a chance that they could be even better if not for these weird patterns, but I think the only way to know is to try different models (especially the ones with registers, in this case).
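
If you do want to experiment with masking, a rough sketch would be something like the function below (the 3x-median threshold is arbitrary and would need tuning for your model and layer):

```python
# Rough sketch: zero out patch tokens whose norm is far above the median.
# The 3x-median threshold is arbitrary and would need tuning.
import torch

def mask_high_norm_patches(patch_tokens: torch.Tensor, factor: float = 3.0):
    """patch_tokens: (num_patches, dim); returns masked tokens and a keep-mask."""
    norms = patch_tokens.norm(dim=-1)          # (num_patches,)
    keep = norms < factor * norms.median()     # True for "normal" patches
    masked = patch_tokens.clone()
    masked[~keep] = 0.0                        # suppress suspected artifact tokens
    return masked, keep
```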

@amundra15 (Author)

A short update regarding the issue:
The artefacts turned out to be related to training instability as well (somewhat expected). After modifying our (custom) loss function, we observe more stable training, which leads to no high-norm patches in the feature space as well as better downstream performance.

@Supersak80

@amundra15 Would you please comment on how you modified your loss function and why/how that leads to more stable training? Thank you!

@amundra15 (Author)

I made a couple of major changes to the loss function, tailored to the problem at hand. One of them was difficult to optimize (and also did not make sense intuitively), leading to increasing values for the original losses from the paper. Upon removing that loss term, training became more stable and the original losses also decreased consistently.
