-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from dlmbl/demo
Create knowledge extraction demo project
- Loading branch information
Showing
12 changed files
with
650 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
{ | ||
"full_name": "Your Name", | ||
"email": "[email protected]", | ||
"project_name": "dlmbl project", | ||
"project_name": "dlmbl-project", | ||
"project_slug": "{{ cookiecutter.project_name.lower().replace(' ', '_').replace('-', '_') }}", | ||
"project_short_description": "package description.", | ||
"default_python": "3.10" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
{ | ||
"template": "https://github.com/dlmbl/example-project", | ||
"commit": "4fff80ba77169cfa6592cf6c9473e6ca98be4da2", | ||
"checkout": null, | ||
"context": { | ||
"cookiecutter": { | ||
"full_name": "Diane Adjavon", | ||
"email": "adjavond [at] janelia [dot] hhmi [dot] org", | ||
"project_name": "knowledge extraction", | ||
"project_slug": "knowledge_extraction", | ||
"project_short_description": "A repo to extract knowledge from a classifier", | ||
"default_python": "3.10", | ||
"_template": "https://github.com/dlmbl/example-project" | ||
} | ||
}, | ||
"directory": null | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
data/ | ||
checkpoints/ | ||
|
||
# data files | ||
*.zarr | ||
*.tiff | ||
*.tif | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
pip-wheel-metadata/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Pdoc documentation | ||
docs/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
pyrepo | ||
.vscode/ | ||
|
||
# OS Files | ||
.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# knowledge extraction | ||
|
||
The goal of this project is to train a simple classifier to classify images in the Colored MNIST dataset. It was adapted from the [DL@MBL Knowledge Extraction exercise](https://github.com/dlmbl/knowledge_extraction). | ||
|
||
## Getting started | ||
To create a new Python environment and install the package: | ||
```bash | ||
mamba create -n knowledge python=3.10 | ||
mamba activate knowledge | ||
pip install -e . | ||
``` |
247 changes: 247 additions & 0 deletions
247
knowledge_extraction/notebooks/2024-08-26-test-model.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
[build-system] | ||
build-backend = "setuptools.build_meta" | ||
requires = ["setuptools", "wheel"] | ||
|
||
[project] | ||
name = "knowledge-extraction" | ||
description = "A repo to extract knowledge from a classifier" | ||
readme = "README.md" | ||
requires-python = ">=3.7" | ||
classifiers = [ | ||
"Programming Language :: Python :: 3", | ||
] | ||
keywords = [] | ||
license = { text = "BSD 3-Clause License" } | ||
authors = [ | ||
{ email = "[email protected]", name = "Diane Adjavon" }, | ||
] | ||
dynamic = ["version"] | ||
dependencies = [ | ||
'matplotlib', | ||
'torch', | ||
'torchvision', | ||
'tqdm', | ||
'scikit-learn', | ||
'seaborn' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" | ||
This script was used to train the pre-trained model weights that were given as an option during the exercise. | ||
""" | ||
|
||
from knowledge_extraction.model import DenseModel | ||
from knowledge_extraction.data import ColoredMNIST | ||
import torch | ||
from torch.utils.data import DataLoader | ||
from tqdm import tqdm | ||
from pathlib import Path | ||
|
||
|
||
def train_classifier(base_dir, epochs=10):
    """Train a ``DenseModel`` on Colored MNIST and save its weights.

    Parameters
    ----------
    base_dir: Union[str, pathlib.Path]
        Reference directory. Checkpoints are written to
        ``base_dir/../checkpoints`` and the data set is downloaded to
        ``base_dir/../data``; both directories are created if missing.
    epochs: int
        Number of passes over the training set; default is 10.

    Notes
    -----
    ``model.pth`` is overwritten after every epoch, so only the latest
    weights are kept. ``losses.txt`` holds one value per epoch: the loss of
    the *last batch* of that epoch (not an epoch average).
    """
    checkpoint_dir = Path(base_dir) / "../checkpoints"
    # parents=True: do not abort if an intermediate directory is missing.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    data_dir = Path(base_dir) / "../data"
    data_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move the model to its device before handing its parameters to the
    # optimizer.
    model = DenseModel((28, 28, 3), 4)
    model.to(device)

    data = ColoredMNIST(data_dir, download=True, train=True)
    dataloader = DataLoader(
        data,
        batch_size=32,
        shuffle=True,
        # Pinned memory only helps host-to-GPU copies; requesting it on a
        # CPU-only machine just triggers a warning.
        pin_memory=device.type == "cuda",
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss()

    losses = []
    for epoch in range(epochs):
        for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"):
            optimizer.zero_grad()
            y_pred = model(x.to(device))
            loss = loss_fn(y_pred, y.to(device))
            loss.backward()
            optimizer.step()
        # `loss` is the loss of the final batch of this epoch.
        print(f"Epoch {epoch}: Loss = {loss.item()}")
        losses.append(loss.item())
        # TODO save every epoch instead of overwriting?
        torch.save(model.state_dict(), checkpoint_dir / "model.pth")

    with open(checkpoint_dir / "losses.txt", "w") as f:
        f.write("\n".join(str(value) for value in losses))
|
||
|
||
if __name__ == "__main__":
    # Resolve output locations relative to this script, not the caller's
    # working directory.
    train_classifier(base_dir=Path(__file__).parent, epochs=10)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" | ||
This script was used to validate the pre-trained classifier. | ||
""" | ||
|
||
from knowledge_extraction.model import DenseModel | ||
from knowledge_extraction.data import ColoredMNIST | ||
import numpy as np | ||
import torch | ||
from torch.utils.data import DataLoader | ||
from tqdm import tqdm | ||
|
||
|
||
def confusion_matrix(labels, predictions, n_classes=None):
    """Count how often each true class was predicted as each class.

    Parameters
    ----------
    labels: Iterable[int]
        True class indices (non-negative integers).
    predictions: Iterable[int]
        Predicted class indices, paired element-wise with `labels`.
    n_classes: Optional[int]
        Size of the (square) matrix. When omitted it is inferred as the
        largest index seen in either input plus one, so classes that appear
        only in `predictions` (or are absent from `labels`) still fit.

    Returns
    -------
    matrix: np.ndarray
        ``(n_classes, n_classes)`` float array where ``matrix[i, j]`` is the
        number of samples with true class ``i`` predicted as class ``j``.
    """
    labels = [int(label) for label in labels]
    predictions = [int(pred) for pred in predictions]
    if n_classes is None:
        # Counting distinct labels under-sizes the matrix when a class is
        # missing from `labels`; size from the largest index instead.
        n_classes = max(labels + predictions, default=-1) + 1
    matrix = np.zeros((n_classes, n_classes))
    for label, pred in zip(labels, predictions):
        matrix[label, pred] += 1
    return matrix
|
||
|
||
def validate_classifier(checkpoint_dir):
    """Evaluate the saved classifier on the Colored MNIST test split.

    Loads ``model.pth`` from `checkpoint_dir`, predicts every test sample and
    writes the resulting confusion matrix to
    ``checkpoint_dir/confusion_matrix.txt`` as integers.

    Parameters
    ----------
    checkpoint_dir: Union[str, pathlib.Path]
        Directory holding ``model.pth``; the confusion matrix is written
        here as well.

    Notes
    -----
    The data set is read from the relative path ``"data"`` (no download), so
    the result depends on the current working directory.
    """
    data = ColoredMNIST("data", download=False, train=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataloader = DataLoader(
        data,
        batch_size=32,
        shuffle=False,
        # Pinned memory is only useful for host-to-GPU transfers.
        pin_memory=device.type == "cuda",
        drop_last=False,
    )

    model = DenseModel((28, 28, 3), 4)
    model.to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine (and vice versa).
    state_dict = torch.load(
        f"{checkpoint_dir}/model.pth", map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)
    # Inference mode: fix dropout/batch-norm behaviour if the model has any.
    model.eval()

    labels = []
    predictions = []
    # No gradients are needed for validation; skip building the graph.
    with torch.no_grad():
        for x, y in tqdm(dataloader, desc="Validation"):
            pred = model(x.to(device))
            pred_y = torch.argmax(pred, dim=1)
            labels.extend(y.numpy())
            predictions.extend(pred_y.cpu().numpy())

    # Get confusion matrix
    matrix = confusion_matrix(labels, predictions)
    # Save matrix as text
    np.savetxt(f"{checkpoint_dir}/confusion_matrix.txt", matrix, fmt="%d")
|
||
|
||
if __name__ == "__main__":
    # "checkpoints" is resolved against the current working directory.
    validate_classifier(checkpoint_dir="checkpoints")
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"""Colored MNIST Data set.""" | ||
|
||
from pathlib import Path | ||
|
||
import matplotlib.pyplot as plt | ||
import torch | ||
import torchvision | ||
import torchvision.transforms as transforms | ||
from torch.utils.data import DataLoader, random_split | ||
|
||
download_path = Path(__file__).parent / "downloads" | ||
|
||
|
||
def colorize(image, color):
    """Scale a grayscale image by each component of `color` and stack the results.

    One scaled copy of `image` is produced per entry of `color`; the copies
    are stacked along dimension 1 and all singleton dimensions are then
    squeezed away.
    """
    channels = [image * weight for weight in color]
    stacked = torch.stack(channels, dim=1)
    return stacked.squeeze()
|
||
|
||
def get_color(condition_labels, condition, sample):
    """Get matplotlib color based on condition and random sample.

    Parameters
    ----------
    condition_labels: List[str]
        List of available conditions; i.e. `matplotlib` colormap names.
    condition: int
        The index of the condition in `condition_labels`.
    sample: float
        Sampling value for the colormap, must be between 0 and 1.

    Returns
    -------
    color:
        The RGB part of the colormap's RGBA output for `sample` (the trailing
        alpha value is dropped). For a scalar `sample` matplotlib returns a
        tuple, so this is a 3-tuple of floats.
    """
    # `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
    # `plt.get_cmap` performs the same lookup by name and is still available.
    colormap = plt.get_cmap(condition_labels[condition])
    # Calling the colormap yields RGBA; keep only RGB.
    return colormap(sample)[:-1]
|
||
|
||
class ColoredMNIST(torchvision.datasets.MNIST):
    """MNIST with added color.

    The original MNIST images make up the content of the data set.
    They are styled with colors sampled from `matplotlib` colormaps.
    The colormaps correspond to the data's condition.

    Each sample is ``(colored_image, condition)``: the target is the index
    of the colormap used, not the digit shown in the image.
    """

    def __init__(self, root, classes=None, train=True, download=False):
        """
        Parameters
        ----------
        root: Union[str, pathlib.Path]
            Data root passed to `torchvision.datasets.MNIST`; required.
        classes: List[str]
            The names of the `matplotlib` colormaps to use; defaults to the
            conditions: `['spring', 'summer', 'autumn', 'winter']`.
        train: bool
            Passed to `torchvision.datasets.MNIST`; default is True
        download: bool
            Passed to `torchvision.datasets.MNIST`; default is False
        """
        super().__init__(root, train=train, download=download)
        if classes is None:
            self.classes = ["spring", "summer", "autumn", "winter"]
        else:
            self.classes = classes
        # Initialise a random set of conditions, of the same length as the data
        self.conditions = torch.randint(len(self.classes), (len(self),))
        # Initialise a set of style values, the actual color will be dependent
        # on the condition
        self.style_values = torch.rand((len(self),))
        # Colors are fixed once at construction so a given index always
        # yields the same colored image for this instance.
        self.colors = [
            get_color(self.classes, condition, sample)
            for condition, sample in zip(
                self.conditions.numpy(), self.style_values.numpy()
            )
        ]

    def __getitem__(self, item):
        """Return ``(colored_image, condition)`` for index `item`.

        The digit label supplied by the parent data set is intentionally
        discarded; the classification target is the color condition.
        """
        image, _ = super().__getitem__(item)
        image = transforms.ToTensor()(image)
        color = torch.Tensor(self.colors[item])
        condition = self.conditions[item]
        return colorize(image, color), condition
Oops, something went wrong.