-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 8dcefbe
Showing
40 changed files
with
5,426 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
*.pyc | ||
*.out | ||
**/__pycache__/ | ||
# !plots |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "model_symmetries/ct"] | ||
path = model_symmetries/ct | ||
url = https://github.com/SHI-Labs/Compact-Transformers.git |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
This material was prepared as an account of work sponsored by an agency of the | ||
United States Government. Neither the United States Government nor the United | ||
States Department of Energy, nor Battelle, nor any of their employees, nor any | ||
jurisdiction or organization that has cooperated in the development of these | ||
materials, makes any warranty, express or implied, or assumes any legal | ||
liability or responsibility for the accuracy, completeness, or usefulness or any | ||
information, apparatus, product, software, or process disclosed, or represents | ||
that its use would not infringe privately owned rights. Reference herein to any | ||
specific commercial product, process, or service by trade name, trademark, | ||
manufacturer, or otherwise does not necessarily constitute or imply its | ||
endorsement, recommendation, or favoring by the United States Government or any | ||
agency thereof, or Battelle Memorial Institute. The views and opinions of | ||
authors expressed herein do not necessarily state or reflect those of the United | ||
States Government or any agency thereof. | ||
|
||
PACIFIC NORTHWEST NATIONAL LABORATORY | ||
operated by | ||
BATTELLE | ||
for the | ||
UNITED STATES DEPARTMENT OF ENERGY | ||
under Contract DE-AC05-76RL01830 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
Simplified BSD | ||
____________________________________________ | ||
Copyright 2022 Battelle Memorial Institute | ||
|
||
Redistribution and use in source and binary forms, with or without modification, | ||
are permitted provided that the following conditions are met: | ||
|
||
1. Redistributions of source code must retain the above copyright notice, this | ||
list of conditions and the following disclaimer. | ||
|
||
2. Redistributions in binary form must reproduce the above copyright notice, | ||
this list of conditions and the following disclaimer in the documentation | ||
and/or other materials provided with the distribution. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | ||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | ||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR | ||
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | ||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | ||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON | ||
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
Code to run the experiments of the Neurips 2022 paper [On the Symmetries of Deep | ||
Learning Models and their Internal | ||
Representations](https://arxiv.org/abs/2205.14258). | ||
|
||
# Overview | ||
|
||
This repository is currently organized into a module `model_symmetries` with | ||
submodules `stitching` and `alignment`, corresponding to sections 4 and 5 of | ||
the paper (for the network dissection results of section 6 we used the | ||
implementation at | ||
[https://github.com/CSAILVision/NetDissect-Lite](https://github.com/CSAILVision/NetDissect-Lite)). | ||
|
||
In addition there are some submodules containing code shared across `stitching` | ||
and `alignment`, namely | ||
- `models.py,` `datasets.py`, `train.py` and `plotting.py` (self explanatory) | ||
- `zoo.py`: utilities to train a bunch of models from independent random seeds | ||
- `constants.py`: specify a directory in which to store data/models/results by | ||
defining the variable `data_dir`. | ||
|
||
## `stitching` | ||
|
||
The key classes for stitching layers and stitched models are in `stitching.py`. | ||
In particular, we direct attention towards the `Birkhoff` class, which | ||
implements for our approach using PGD on the Birkhoff polytope of doubly | ||
stochastic matrices. | ||
|
||
`train.py` has more options than is typical, due to a few major implementation | ||
considerations: | ||
1. The need to make sure that when stitching, we *only* update parameters of the | ||
stitching layer. | ||
2. The overhead of PGD and extra $-\ell_2$ regularization. | ||
3. The necessity of a no-grad training epoch before validation. | ||
|
||
The main experiment script is `cifar10_stitching.py`. This also has many | ||
options, due to the number of combinations of model/stitching layer type we | ||
consider. | ||
|
||
In order to run the experiments stitching Compact Convolutional Transformers, | ||
you will need | ||
[https://github.com/SHI-Labs/Compact-Transformers](https://github.com/SHI-Labs/Compact-Transformers), | ||
which is included as a Git submodule of this repository at | ||
`model_symmetries/ct`. To initialize and update it, run | ||
``` bash | ||
git submodule init && git submodule update | ||
``` | ||
|
||
## `alignment` | ||
|
||
Core functions are located in `alignment.py`. The $G_{\mathrm{ReLU}}$-Procrustes | ||
and CKA metrics are `wreath_{procrustes,cka}` (the group $G_{\mathrm{ReLU}}$ is | ||
an example of a [wreath product](https://en.wikipedia.org/wiki/Wreath_product), | ||
hence the name). | ||
|
||
## Visualization | ||
|
||
`plotting.py` contains functions for displaying stitching penalties and | ||
dissimilarity metrics, which can be run in the notebook `plotting.ipynb`. | ||
|
||
## Parallelization | ||
|
||
We ran these experiments on a cluster managed by | ||
[SLURM](https://slurm.schedmd.com/documentation.html) -- files ending in | ||
`.slurm` are SLURM batch files. In order to distribute the many sweeps in these | ||
experiments across nodes of the cluster, we submitted batches to the queue using | ||
loops found in the bash scripts (files ending in `.sh`). **WARNING**: executing | ||
these scripts will consume many GPU days. | ||
|
||
# Citation | ||
|
||
If you find this code useful, please cite our paper. | ||
|
||
```bibtex | ||
@article{modelsyms2022, | ||
doi = {10.48550/ARXIV.2205.14258}, | ||
url = {https://arxiv.org/abs/2205.14258}, | ||
author = {Godfrey, Charles and Brown, Davis and Emerson, Tegan and Kvinge, Henry}, | ||
keywords = {Machine Learning (cs.LG), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences}, | ||
title = {On the Symmetries of Deep Learning Models and their Internal Representations}, | ||
publisher = {arXiv}, | ||
year = {2022}, | ||
copyright = {Creative Commons Attribution 4.0 International} | ||
} | ||
``` | ||
|
||
# Notice | ||
|
||
This research was supported by the Mathematics for Artificial Reasoning in Science (MARS) | ||
initiative at Pacific Northwest National Laboratory. It was conducted under the Laboratory Directed | ||
Research and Development (LDRD) Program at at Pacific Northwest National Laboratory (PNNL), a | ||
multiprogram National Laboratory operated by Battelle Memorial Institute for the U.S. Department | ||
of Energy under Contract DE-AC05-76RL01830. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/usr/bin/env bash | ||
# {your SBATCH header here} | ||
|
||
# your environment specifications here (e.t. activate your conda env) | ||
|
||
python -m model_symmetries.alignment.cifar10_alignment --repeats=$3 --arch=$1 --similarity=$2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#!/usr/bin/env bash | ||
|
||
# procrustes | ||
for a in {'mCNN','resnet20','resnet18','cwCNN'}; do | ||
for d in {'procrustes','ortho_procrustes'}; do | ||
sbatch --job-name=$a$d -o $a$d'.out' -e $a$d'.out' align.slurm $a $d 32; | ||
done; | ||
done | ||
|
||
# cka | ||
for a in {'mCNN','resnet20','resnet18','cwCNN'}; do | ||
for d in {'cka','ortho_cka'}; do | ||
sbatch --job-name=$a$d -o $a$d'.out' -e $a$d'.out' align.slurm $a $d 16; | ||
done; | ||
done |
Empty file.
Empty file.
Oops, something went wrong.