first commit
godfrey-cw committed Nov 8, 2022
0 parents commit 8dcefbe
Showing 40 changed files with 5,426 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
*.pyc
*.out
**/__pycache__/
# !plots
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "model_symmetries/ct"]
path = model_symmetries/ct
url = https://github.com/SHI-Labs/Compact-Transformers.git
21 changes: 21 additions & 0 deletions DISCLAIMER
@@ -0,0 +1,21 @@
This material was prepared as an account of work sponsored by an agency of the
United States Government. Neither the United States Government nor the United
States Department of Energy, nor Battelle, nor any of their employees, nor any
jurisdiction or organization that has cooperated in the development of these
materials, makes any warranty, express or implied, or assumes any legal
liability or responsibility for the accuracy, completeness, or usefulness of any
information, apparatus, product, software, or process disclosed, or represents
that its use would not infringe privately owned rights. Reference herein to any
specific commercial product, process, or service by trade name, trademark,
manufacturer, or otherwise does not necessarily constitute or imply its
endorsement, recommendation, or favoring by the United States Government or any
agency thereof, or Battelle Memorial Institute. The views and opinions of
authors expressed herein do not necessarily state or reflect those of the United
States Government or any agency thereof.

PACIFIC NORTHWEST NATIONAL LABORATORY
operated by
BATTELLE
for the
UNITED STATES DEPARTMENT OF ENERGY
under Contract DE-AC05-76RL01830
24 changes: 24 additions & 0 deletions LICENSE
@@ -0,0 +1,24 @@
Simplified BSD
____________________________________________
Copyright 2022 Battelle Memorial Institute

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
91 changes: 91 additions & 0 deletions README.md
@@ -0,0 +1,91 @@
Code to run the experiments of the NeurIPS 2022 paper [On the Symmetries of Deep
Learning Models and their Internal
Representations](https://arxiv.org/abs/2205.14258).

# Overview

This repository is currently organized into a module `model_symmetries` with
submodules `stitching` and `alignment`, corresponding to sections 4 and 5 of
the paper (for the network dissection results of section 6 we used the
implementation at
[https://github.com/CSAILVision/NetDissect-Lite](https://github.com/CSAILVision/NetDissect-Lite)).

In addition there are some submodules containing code shared across `stitching`
and `alignment`, namely
- `models.py`, `datasets.py`, `train.py` and `plotting.py` (self-explanatory)
- `zoo.py`: utilities to train a bunch of models from independent random seeds
- `constants.py`: specify a directory in which to store data/models/results by
defining the variable `data_dir`.
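
As a minimal sketch of what such a `constants.py` could look like: only the variable name `data_dir` comes from this README. The placeholder path and the helper `subdir` are assumptions for illustration, so adapt them to your own storage layout.

```python
# Hypothetical constants.py sketch; only the name `data_dir` is taken from the README.
from pathlib import Path

# Root directory under which data, models and results are stored.
# The location below is a placeholder, not the path used by the authors.
data_dir = Path.home() / "model_symmetries_data"


def subdir(name: str) -> Path:
    """Return (without creating) the sub-directory `name` under data_dir."""
    return data_dir / name
```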

## `stitching`

The key classes for stitching layers and stitched models are in `stitching.py`.
In particular, we direct attention towards the `Birkhoff` class, which
implements our approach using PGD on the Birkhoff polytope of doubly
stochastic matrices.
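
For intuition, the Birkhoff polytope is the set of nonnegative square matrices whose rows and columns each sum to 1, and its vertices are the permutation matrices. The sketch below uses Sinkhorn normalization to map a positive matrix to an (approximately) doubly stochastic one. This is a stand-in chosen for brevity: the function name `sinkhorn` and its parameters are assumptions, and the `Birkhoff` class may well use a different projection.

```python
import numpy as np


def sinkhorn(matrix, n_iters=200, eps=1e-12):
    """Alternately normalize rows and columns of a positive square matrix.

    Illustrative only: this is Sinkhorn normalization, not necessarily the
    projection step used by the repository's `Birkhoff` class.
    """
    a = np.asarray(matrix, dtype=float)
    for _ in range(n_iters):
        a = a / (a.sum(axis=1, keepdims=True) + eps)  # make each row sum to ~1
        a = a / (a.sum(axis=0, keepdims=True) + eps)  # make each column sum to ~1
    return a


rng = np.random.default_rng(0)
ds = sinkhorn(rng.uniform(0.1, 1.0, size=(4, 4)))
print(ds.sum(axis=0), ds.sum(axis=1))  # both close to a vector of ones
```

In a PGD loop, a step of this kind would follow each gradient update so that the stitching matrix stays (approximately) inside the polytope.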

`train.py` has more options than is typical, due to a few major implementation
considerations:
1. The need to make sure that when stitching, we *only* update parameters of the
stitching layer.
2. The overhead of PGD and extra $-\ell_2$ regularization.
3. The necessity of a no-grad training epoch before validation.

The main experiment script is `cifar10_stitching.py`. This also has many
options, due to the number of combinations of model/stitching layer type we
consider.

In order to run the experiments stitching Compact Convolutional Transformers,
you will need
[https://github.com/SHI-Labs/Compact-Transformers](https://github.com/SHI-Labs/Compact-Transformers),
which is included as a Git submodule of this repository at
`model_symmetries/ct`. To initialize and update it, run
``` bash
git submodule init && git submodule update
```

## `alignment`

Core functions are located in `alignment.py`. The $G_{\mathrm{ReLU}}$-Procrustes
and CKA metrics are `wreath_{procrustes,cka}` (the group $G_{\mathrm{ReLU}}$ is
an example of a [wreath product](https://en.wikipedia.org/wiki/Wreath_product),
hence the name).
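
To illustrate why this group is relevant, the sketch below assumes (following our reading of the paper) that $G_{\mathrm{ReLU}}$ consists of matrices $PD$ with $P$ a permutation matrix and $D$ a diagonal matrix with positive entries, and checks numerically that such matrices commute with ReLU. The helper names `relu` and `random_g_relu` are hypothetical and do not appear in `alignment.py`.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def random_g_relu(n, rng):
    """Sample P @ D with P a permutation matrix and D a positive diagonal matrix."""
    perm = np.eye(n)[rng.permutation(n)]            # rows of the identity, shuffled
    scale = np.diag(rng.uniform(0.5, 2.0, size=n))  # strictly positive scalings
    return perm @ scale


rng = np.random.default_rng(0)
g = random_g_relu(5, rng)
x = rng.normal(size=5)

# Permuting and positively rescaling coordinates commutes with ReLU.
commutes = np.allclose(relu(g @ x), g @ relu(x))
print(commutes)
```

A negative scaling breaks the identity, since `relu(-x)` and `-relu(x)` differ for any nonzero `x`, which is why the diagonal entries must be positive.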

## Visualization

`plotting.py` contains functions for displaying stitching penalties and
dissimilarity metrics, which can be run in the notebook `plotting.ipynb`.

## Parallelization

We ran these experiments on a cluster managed by
[SLURM](https://slurm.schedmd.com/documentation.html) -- files ending in
`.slurm` are SLURM batch files. In order to distribute the many sweeps in these
experiments across nodes of the cluster, we submitted batches to the queue using
loops found in the bash scripts (files ending in `.sh`). **WARNING**: executing
these scripts will consume many GPU days.
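
To preview a sweep before spending any GPU time, a dry run can be sketched as follows. It mirrors the loops in `align_launch.sh` but only prints the `sbatch` commands instead of submitting them. The names `ARCHS`, `SWEEPS` and `sbatch_commands` are assumptions introduced for this example.

```python
from itertools import product

# Architectures and (metrics, repeats) pairs as they appear in align_launch.sh.
ARCHS = ["mCNN", "resnet20", "resnet18", "cwCNN"]
SWEEPS = [
    (["procrustes", "ortho_procrustes"], 32),
    (["cka", "ortho_cka"], 16),
]


def sbatch_commands():
    """Build (but do not run) one sbatch command per architecture/metric pair."""
    cmds = []
    for metrics, repeats in SWEEPS:
        for arch, metric in product(ARCHS, metrics):
            name = f"{arch}{metric}"
            cmds.append([
                "sbatch", f"--job-name={name}",
                "-o", f"{name}.out", "-e", f"{name}.out",
                "align.slurm", arch, metric, str(repeats),
            ])
    return cmds


commands = sbatch_commands()
for cmd in commands:
    print(" ".join(cmd))  # dry run: print instead of submitting
```

Passing each list to `subprocess.run` would submit the jobs, which is deliberately left out here.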

# Citation

If you find this code useful, please cite our paper.

```bibtex
@article{modelsyms2022,
doi = {10.48550/ARXIV.2205.14258},
url = {https://arxiv.org/abs/2205.14258},
author = {Godfrey, Charles and Brown, Davis and Emerson, Tegan and Kvinge, Henry},
keywords = {Machine Learning (cs.LG), Artificial Intelligence (cs.AI), FOS: Computer and information sciences},
title = {On the Symmetries of Deep Learning Models and their Internal Representations},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}
```

# Notice

This research was supported by the Mathematics for Artificial Reasoning in Science (MARS)
initiative at Pacific Northwest National Laboratory. It was conducted under the Laboratory Directed
Research and Development (LDRD) Program at Pacific Northwest National Laboratory (PNNL), a
multiprogram National Laboratory operated by Battelle Memorial Institute for the U.S. Department
of Energy under Contract DE-AC05-76RL01830.
6 changes: 6 additions & 0 deletions align.slurm
@@ -0,0 +1,6 @@
#!/usr/bin/env bash
# {your SBATCH header here}

# your environment specifications here (e.g., activate your conda env)

# Positional arguments: $1 = architecture, $2 = similarity metric, $3 = number of repeats
python -m model_symmetries.alignment.cifar10_alignment --repeats="$3" --arch="$1" --similarity="$2"
15 changes: 15 additions & 0 deletions align_launch.sh
@@ -0,0 +1,15 @@
#!/usr/bin/env bash

# procrustes
for a in mCNN resnet20 resnet18 cwCNN; do
  for d in procrustes ortho_procrustes; do
    sbatch --job-name="$a$d" -o "$a$d.out" -e "$a$d.out" align.slurm "$a" "$d" 32
  done
done

# cka
for a in mCNN resnet20 resnet18 cwCNN; do
  for d in cka ortho_cka; do
    sbatch --job-name="$a$d" -o "$a$d.out" -e "$a$d.out" align.slurm "$a" "$d" 16
  done
done
Empty file added model_symmetries/__init__.py
Empty file.