Merge pull request #3 from dlmbl/demo
Create knowledge extraction demo project
msschwartz21 authored Aug 29, 2024
2 parents e81b21b + 5eb3fa2 commit 211e167
Showing 12 changed files with 650 additions and 1 deletion.
2 changes: 1 addition & 1 deletion cookiecutter.json
@@ -1,7 +1,7 @@
{
"full_name": "Your Name",
"email": "[email protected]",
"project_name": "dlmbl project",
"project_name": "dlmbl-project",
"project_slug": "{{ cookiecutter.project_name.lower().replace(' ', '_').replace('-', '_') }}",
"project_short_description": "package description.",
"default_python": "3.10"
17 changes: 17 additions & 0 deletions knowledge_extraction/.cruft.json
@@ -0,0 +1,17 @@
{
"template": "https://github.com/dlmbl/example-project",
"commit": "4fff80ba77169cfa6592cf6c9473e6ca98be4da2",
"checkout": null,
"context": {
"cookiecutter": {
"full_name": "Diane Adjavon",
"email": "adjavond [at] janelia [dot] hhmi [dot] org",
"project_name": "knowledge extraction",
"project_slug": "knowledge_extraction",
"project_short_description": "A repo to extract knowledge from a classifier",
"default_python": "3.10",
"_template": "https://github.com/dlmbl/example-project"
}
},
"directory": null
}
132 changes: 132 additions & 0 deletions knowledge_extraction/.gitignore
@@ -0,0 +1,132 @@
data/
checkpoints/

# data files
*.zarr
*.tiff
*.tif

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Pdoc documentation
docs/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
pyrepo
.vscode/

# OS Files
.DS_Store
11 changes: 11 additions & 0 deletions knowledge_extraction/README.md
@@ -0,0 +1,11 @@
# knowledge extraction

The goal of this project is to train a simple image classifier on the Colored MNIST dataset. It was adapted from the [DL@MBL Knowledge Extraction exercise](https://github.com/dlmbl/knowledge_extraction).

## Getting started
To create a new Python environment and install the package:
```bash
mamba create -n knowledge python=3.10
mamba activate knowledge
pip install -e .
```
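
Once the environment is ready, a quick smoke test is to build the dataset and classifier directly in Python. This is only a sketch based on the `ColoredMNIST` and `DenseModel` signatures used in `scripts/train.py`; it is not part of this commit.

```python
# Minimal smoke test of the installed package (sketch; signatures taken from scripts/train.py).
from knowledge_extraction.data import ColoredMNIST
from knowledge_extraction.model import DenseModel

data = ColoredMNIST("data", download=True, train=True)  # downloads MNIST and adds color
model = DenseModel((28, 28, 3), 4)                       # one output class per colormap

image, condition = data[0]
print(image.shape, condition)        # expected: a (3, 28, 28) tensor and a condition index 0-3
logits = model(image.unsqueeze(0))   # batch of one, as in the training loop
print(logits.shape)                  # expected: (1, 4)
```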
247 changes: 247 additions & 0 deletions knowledge_extraction/notebooks/2024-08-26-test-model.ipynb

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions knowledge_extraction/pyproject.toml
@@ -0,0 +1,26 @@
[build-system]
build-backend = "setuptools.build_meta"
requires = ["setuptools", "wheel"]

[project]
name = "knowledge-extraction"
description = "A repo to extract knowledge from a classifier"
readme = "README.md"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Python :: 3",
]
keywords = []
license = { text = "BSD 3-Clause License" }
authors = [
    { email = "[email protected]", name = "Diane Adjavon" },
]
dynamic = ["version"]
dependencies = [
'matplotlib',
'torch',
'torchvision',
'tqdm',
'scikit-learn',
'seaborn'
]
47 changes: 47 additions & 0 deletions knowledge_extraction/scripts/train.py
@@ -0,0 +1,47 @@
"""
This script was used to train the pre-trained model weights that were given as an option during the exercise.
"""

from knowledge_extraction.model import DenseModel
from knowledge_extraction.data import ColoredMNIST
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from pathlib import Path


def train_classifier(base_dir, epochs=10):
checkpoint_dir = Path(base_dir) / "../checkpoints"
checkpoint_dir.mkdir(exist_ok=True)
data_dir = Path(base_dir) / "../data"
data_dir.mkdir(exist_ok=True)
    # Set up the model, dataset, and data loader
model = DenseModel((28, 28, 3), 4)
data = ColoredMNIST(data_dir, download=True, train=True)
dataloader = DataLoader(data, batch_size=32, shuffle=True, pin_memory=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

losses = []
for epoch in range(epochs):
for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"):
optimizer.zero_grad()
y_pred = model(x.to(device))
loss = loss_fn(y_pred, y.to(device))
loss.backward()
optimizer.step()
print(f"Epoch {epoch}: Loss = {loss.item()}")
losses.append(loss.item())
# TODO save every epoch instead of overwriting?
torch.save(model.state_dict(), checkpoint_dir / "model.pth")

with open(checkpoint_dir / "losses.txt", "w") as f:
f.write("\n".join(str(l) for l in losses))


if __name__ == "__main__":
this_dir = Path(__file__).parent
train_classifier(base_dir=this_dir, epochs=10)
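
Not part of the commit, but since the `losses.txt` file written above holds one loss value per epoch, the training curve can be inspected with matplotlib (already a project dependency). The path below assumes the `checkpoints/` directory that `train.py` creates next to the `scripts/` folder; adjust it if the layout differs.

```python
# Quick look at the training curve written by scripts/train.py (sketch; adjust the path as needed).
import matplotlib.pyplot as plt

with open("checkpoints/losses.txt") as f:
    losses = [float(line) for line in f if line.strip()]

plt.plot(losses, marker="o")
plt.xlabel("epoch")
plt.ylabel("cross-entropy loss (last batch of the epoch)")
plt.title("ColoredMNIST classifier training loss")
plt.show()
```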
47 changes: 47 additions & 0 deletions knowledge_extraction/scripts/validate.py
@@ -0,0 +1,47 @@
"""
This script was used to validate the pre-trained classifier.
"""

from knowledge_extraction.model import DenseModel
from knowledge_extraction.data import ColoredMNIST
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def confusion_matrix(labels, predictions):
n_classes = len(set(labels))
matrix = np.zeros((n_classes, n_classes))
for label, pred in zip(labels, predictions):
matrix[label, pred] += 1
return matrix


def validate_classifier(checkpoint_dir):
data = ColoredMNIST("data", download=False, train=False)
dataloader = DataLoader(
data, batch_size=32, shuffle=False, pin_memory=True, drop_last=False
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DenseModel((28, 28, 3), 4)
model.to(device)
model.load_state_dict(torch.load(f"{checkpoint_dir}/model.pth", weights_only=True))

labels = []
predictions = []
    for x, y in tqdm(dataloader, desc="Validation"):
pred = model(x.to(device))
pred_y = torch.argmax(pred, dim=1)
labels.extend(y.numpy())
predictions.extend(pred_y.cpu().numpy())

# Get confusion matrix
matrix = confusion_matrix(labels, predictions)
# Save matrix as text
np.savetxt(f"{checkpoint_dir}/confusion_matrix.txt", matrix, fmt="%d")


if __name__ == "__main__":
validate_classifier(checkpoint_dir="checkpoints")
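
For reference, the `confusion_matrix` helper above puts the true condition on the rows and the prediction on the columns; the toy values below (illustrative only, run in the same module) show the convention.

```python
# Toy illustration of the confusion_matrix helper defined above.
labels = [0, 1, 2, 2]
predictions = [0, 2, 2, 2]
matrix = confusion_matrix(labels, predictions)
print(matrix)
# rows = true condition, columns = predicted condition:
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 0. 2.]]
```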
Empty file.
85 changes: 85 additions & 0 deletions knowledge_extraction/src/knowledge_extraction/data.py
@@ -0,0 +1,85 @@
"""Colored MNIST Data set."""

from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

download_path = Path(__file__).parent / "downloads"


def colorize(image, color):
"""Turn a grayscale image into a single-colored image."""
return torch.stack(tuple(image * x for x in color), dim=1).squeeze()


def get_color(condition_labels, condition, sample):
"""Get matplotlib color based on condition and random sample.
Parameters
----------
condition_labels: List[str]
List of available conditions; i.e. `matplotlib` colormaps.
condition: int
The index of the condition
sample: float
Sampling value for the colormap, must be between 0 and 1.
Returns
-------
color: np.array
(3,) array of RGB values
"""
color = plt.cm.get_cmap(condition_labels[condition])(sample)[:-1]
return color


class ColoredMNIST(torchvision.datasets.MNIST):
"""MNIST with added color.
The original MNIST images make up the content of the data set.
They are styled with colors sampled from `matplotlib` colormaps.
The colormaps correspond to the data's condition.
"""

def __init__(self, root, classes=None, train=True, download=False):
"""
Parameters
----------
root: Union[str, pathlib.Path]
            Data root passed to `torchvision.datasets.MNIST` for the download.
classes: List[str]
The names of the `matplotlib` colormaps to use; defaults to the
conditions: `['spring', 'summer', 'autumn', 'winter']`.
train: bool
Passed to `torchvision.datasets.MNIST`; default is True
download: bool
            Passed to `torchvision.datasets.MNIST`; default is False
"""
super().__init__(root, train=train, download=download)
if classes is None:
self.classes = ["spring", "summer", "autumn", "winter"]
else:
self.classes = classes
# Initialise a random set of conditions, of the same length as the data
self.conditions = torch.randint(len(self.classes), (len(self),))
# Initialise a set of style values, the actual color will be dependent
# on the condition
self.style_values = torch.rand((len(self),))
self.colors = [
get_color(self.classes, condition, sample)
for condition, sample in zip(
self.conditions.numpy(), self.style_values.numpy()
)
]

def __getitem__(self, item):
image, label = super().__getitem__(item)
image = transforms.ToTensor()(image)
color = torch.Tensor(self.colors[item])
condition = self.conditions[item]
label = torch.tensor(label)
return colorize(image, color), condition
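
To make the colorization pipeline above concrete, the snippet below samples a color from one condition's colormap and applies it to a stand-in grayscale digit. It is a sketch that relies only on the `get_color` and `colorize` helpers defined in this file.

```python
# Sketch of the color pipeline: condition -> colormap -> RGB color -> colored digit.
import torch
from knowledge_extraction.data import colorize, get_color

condition_labels = ["spring", "summer", "autumn", "winter"]
gray = torch.rand((1, 28, 28))                 # stand-in for a ToTensor'd MNIST digit
color = get_color(condition_labels, condition=2, sample=0.5)  # RGB triple from 'autumn'
colored = colorize(gray, torch.tensor(color))  # -> (3, 28, 28) tensor
print(colored.shape)
```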