Merge pull request #438 from kavanase/main
Minor Updates (`n_train/val` as percent, pre-commit, restart handling)
Linux-cpp-lisp authored Jul 2, 2024
2 parents a465026 + 27dcae0 commit 2cdd8c5
Showing 20 changed files with 269 additions and 40 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lint.yaml
@@ -18,7 +18,7 @@ jobs:
- name: Black Check
uses: psf/black@stable
with:
version: "22.3.0"
version: "24.4.2"

flake8:
runs-on: ubuntu-latest
@@ -29,7 +29,7 @@
python-version: '3.x'
- name: Install flake8
run: |
pip install flake8==7.0.0
pip install flake8==7.1.0
- name: run flake8
run: |
flake8 . --count --show-source --statistics
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -4,11 +4,11 @@ fail_fast: true

repos:
- repo: https://github.com/psf/black
rev: 22.3.0
rev: 24.4.2
hooks:
- id: black

- repo: https://gitlab.com/pycqa/flake8
rev: 4.0.1
- repo: https://github.com/pycqa/flake8
rev: 7.1.0
hooks:
- id: flake8
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -6,15 +6,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

Most recent change on the bottom.


## Unreleased - 0.6.1
### Added
- add support for equivariance testing of arbitrary Cartesian tensor outputs
- [Breaking] use entry points for `nequip.extension`s (e.g. for field registration)
- alternate neighborlist support enabled with `NEQUIP_NL` environment variable, which can be set to `ase` (default), `matscipy` or `vesin`
- Allow `n_train` and `n_val` to be specified as percentages of datasets.
- Only attempt training restart if `trainer.pth` file present (prevents unnecessary crashes due to file-not-found errors in some cases)

### Changed
- [Breaking] `NEQUIP_MATSCIPY_NL` environment variable no longer supported

### Fixed
- Fixed `flake8` install location in `pre-commit-config.yaml`


## [0.6.0] - 2024-5-10
### Added
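As a side note on the `NEQUIP_NL` entry above: the snippet below is a minimal sketch of the kind of environment-variable-driven backend selection the changelog describes. The backend names (`ase`, `matscipy`, `vesin`) and the `ase` default come from the changelog entry; the helper function itself is illustrative, not the actual nequip implementation.

```python
import os

# Supported backends per the changelog entry; the dispatch helper below is hypothetical.
_SUPPORTED_NL_BACKENDS = ("ase", "matscipy", "vesin")


def choose_neighborlist_backend() -> str:
    """Pick a neighborlist backend from the NEQUIP_NL environment variable (default: "ase")."""
    backend = os.environ.get("NEQUIP_NL", "ase").strip().lower()
    if backend not in _SUPPORTED_NL_BACKENDS:
        raise ValueError(
            f"NEQUIP_NL={backend!r} is not supported; "
            f"expected one of {_SUPPORTED_NL_BACKENDS}"
        )
    return backend
```

Note that, per the `### Changed` entry, the old `NEQUIP_MATSCIPY_NL` variable is no longer supported, so scripts that exported it would need to switch to `NEQUIP_NL=matscipy`.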
3 changes: 3 additions & 0 deletions configs/full.yaml
@@ -181,6 +181,9 @@ save_ema_checkpoint_freq: -1
# training
n_train: 100 # number of training data
n_val: 50 # number of validation data
# alternatively, n_train and n_val can be set as percentages of the dataset size:
# n_train: 70% # 70% of dataset
# n_val: 30% # 30% of dataset (if validation_dataset not set), or 30% of validation_dataset (if set)
learning_rate: 0.005 # learning rate, we found values between 0.01 and 0.005 to work best - this is often one of the most important hyperparameters to tune
batch_size: 5 # batch size, we found it important to keep this small for most applications including forces (1-5); for energy-only training, higher batch sizes work better
validation_batch_size: 10 # batch size for evaluating the model during validation. This does not affect the training results, but using the highest value possible (<=n_val) without running out of memory will speed up your training.
10 changes: 6 additions & 4 deletions nequip/data/_dataset/_ase_dataset.py
@@ -51,10 +51,12 @@ def _ase_dataset_reader(
datas.append(
(
global_index,
AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs)
if global_index in include_frames
# in-memory dataset will ignore this later, but needed for indexing to work out
else None,
(
AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs)
if global_index in include_frames
# in-memory dataset will ignore this later, but needed for indexing to work out
else None
),
)
)
# Save to a tempfile---
1 change: 1 addition & 0 deletions nequip/data/_keys.py
@@ -2,6 +2,7 @@
This is a separate module to compensate for a TorchScript bug that can only recognize constants when they are accessed as attributes of an imported module.
"""

import sys
from typing import List

1 change: 1 addition & 0 deletions nequip/data/dataloader.py
@@ -84,6 +84,7 @@ class PartialSampler(Sampler[int]):
If `None`, defaults to `len(data_source)`.
generator (Generator): Generator used in sampling.
"""

data_source: Dataset
num_samples_per_epoch: int
shuffle: bool
6 changes: 3 additions & 3 deletions nequip/nn/_convnetlayer.py
@@ -149,9 +149,9 @@ def __init__(
# updated with whatever the convolution outputs (which is a full graph module)
self.irreps_out.update(self.conv.irreps_out)
# but with the features updated by the nonlinearity
self.irreps_out[
AtomicDataDict.NODE_FEATURES_KEY
] = self.equivariant_nonlin.irreps_out
self.irreps_out[AtomicDataDict.NODE_FEATURES_KEY] = (
self.equivariant_nonlin.irreps_out
)

def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
# save old features for resnet
3 changes: 3 additions & 0 deletions nequip/nn/_grad_output.py
@@ -20,6 +20,7 @@ class GradientOutput(GraphModuleMixin, torch.nn.Module):
out_field: the field in which to return the computed gradients. Defaults to ``f"d({of})/d({wrt})"`` for each field in ``wrt``.
sign: either 1 or -1; the returned gradient is multiplied by this.
"""

sign: float
_negate: bool
skip: bool
@@ -119,6 +120,7 @@ class PartialForceOutput(GraphModuleMixin, torch.nn.Module):
vectorize: the vectorize option to ``torch.autograd.functional.jacobian``,
false by default since it doesn't work well.
"""

vectorize: bool

def __init__(
@@ -183,6 +185,7 @@ class StressOutput(GraphModuleMixin, torch.nn.Module):
func: the energy model to wrap
do_forces: whether to compute forces as well
"""

do_forces: bool

def __init__(
1 change: 1 addition & 0 deletions nequip/nn/_interaction_block.py
@@ -1,4 +1,5 @@
""" Interaction Block """

from typing import Optional, Dict, Callable

import torch
8 changes: 5 additions & 3 deletions nequip/scripts/evaluate.py
@@ -383,9 +383,11 @@ def main(args=None, running_as_script: bool = True):
if do_metrics:
display_bar = context_stack.enter_context(
tqdm(
bar_format=""
if prog.disable # prog.ncols doesn't exist if disabled
else ("{desc:." + str(prog.ncols) + "}"),
bar_format=(
""
if prog.disable # prog.ncols doesn't exist if disabled
else ("{desc:." + str(prog.ncols) + "}")
),
disable=None,
)
)
15 changes: 13 additions & 2 deletions nequip/scripts/train.py
@@ -1,4 +1,5 @@
""" Train a network."""

import logging
import argparse
import warnings
@@ -7,7 +8,8 @@
# Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance.
import numpy as np # noqa: F401

from os.path import isdir
from os.path import exists, isdir
from shutil import rmtree
from pathlib import Path

import torch
@@ -71,12 +73,21 @@ def main(args=None, running_as_script: bool = True):
if running_as_script:
set_up_script_logger(config.get("log", None), config.verbose)

found_restart_file = isdir(f"{config.root}/{config.run_name}")
found_restart_file = exists(f"{config.root}/{config.run_name}/trainer.pth")
if found_restart_file and not config.append:
raise RuntimeError(
f"Training instance exists at {config.root}/{config.run_name}; "
"either set append to True or use a different root or runname"
)
elif not found_restart_file and isdir(f"{config.root}/{config.run_name}"):
# output directory exists but no ``trainer.pth`` file, suggesting previous run crash during
# first training epoch (usually due to memory):
warnings.warn(
f"Previous run folder at {config.root}/{config.run_name} exists, but a saved model "
f"(trainer.pth file) was not found. This folder will be cleared and a fresh training run will "
f"be started."
)
rmtree(f"{config.root}/{config.run_name}")

# for fresh new train
if not found_restart_file:
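Pulled out of the `train.py` diff above, here is the new restart logic as a self-contained sketch (simplified; the helper name `check_restart` and its explicit `root`/`run_name`/`append` arguments are illustrative stand-ins for the corresponding config values):

```python
import warnings
from os.path import exists, isdir
from shutil import rmtree


def check_restart(root: str, run_name: str, append: bool) -> bool:
    """Return True if a restartable run exists; mirrors the logic added in this commit."""
    run_dir = f"{root}/{run_name}"
    found_restart_file = exists(f"{run_dir}/trainer.pth")
    if found_restart_file and not append:
        raise RuntimeError(
            f"Training instance exists at {run_dir}; "
            "either set append to True or use a different root or run_name"
        )
    elif not found_restart_file and isdir(run_dir):
        # The output directory exists but holds no trainer.pth, which suggests a
        # previous run crashed during the first epoch; clear it and start fresh.
        warnings.warn(
            f"Previous run folder at {run_dir} exists, but a saved model "
            "(trainer.pth file) was not found. This folder will be cleared and a "
            "fresh training run will be started."
        )
        rmtree(run_dir)
    return found_restart_file
```

The key behavioral change is the switch from `isdir(run_dir)` to `exists(f"{run_dir}/trainer.pth")` as the restart test: an output folder left behind by a crash before the first checkpoint no longer blocks a new run.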
86 changes: 70 additions & 16 deletions nequip/train/trainer.py
@@ -7,6 +7,7 @@
make an interface with ray
"""

import sys
import inspect
import logging
@@ -107,7 +108,7 @@ class Trainer:
- "trainer_save.pth": all the training information. The file used for loading and restart
For restart run, the default set up is to not append to the original folders and files.
The Output class will automatically build a folder call root/run_name
The Output class will automatically build a folder called ``root/run_name``
If append mode is on, the log file will be appended and the best model and last model will be overwritten.
More examples can be found in tests/train/test_trainer.py
@@ -157,9 +158,9 @@ class Trainer:
batch_size (int): size of each batch
validation_batch_size (int): batch size for evaluating the model for validation
shuffle (bool): parameters for dataloader
n_train (int): # of frames for training
n_train (int, str): # of frames for training (as int, or as a percentage string)
n_train_per_epoch (optional int): how many frames from `n_train` to use each epoch; see `PartialSampler`. When `None`, all `n_train` frames will be used each epoch.
n_val (int): # of frames for validation
n_val (int, str): # of frames for validation (as int, or as a percentage string)
exclude_keys (list): fields from dataset to ignore.
dataloader_num_workers (int): `num_workers` for the `DataLoader`s
train_idcs (optional, list): list of frames to use for training
@@ -250,9 +251,9 @@ def __init__(
batch_size: int = 5,
validation_batch_size: int = 5,
shuffle: bool = True,
n_train: Optional[int] = None,
n_train: Optional[Union[int, str]] = None,
n_train_per_epoch: Optional[int] = None,
n_val: Optional[int] = None,
n_val: Optional[Union[int, str]] = None,
dataloader_num_workers: int = 0,
train_idcs: Optional[list] = None,
val_idcs: Optional[list] = None,
@@ -754,7 +755,6 @@ def init_metrics(self):
)

def train(self):

"""Training"""
if getattr(self, "dl_train", None) is None:
raise RuntimeError("You must call `set_dataset()` before calling `train()`")
@@ -1144,12 +1144,59 @@ def __del__(self):
for i in range(len(logger.handlers)):
logger.handlers.pop()

def _parse_n_train_n_val(
self, train_dataset_size: int, val_dataset_size: int
) -> Tuple[int, int]:
# parse n_train and n_val (can be ints or str with percentage):
n_train_n_val = []
for n_name, dataset_size in (
("n_train", train_dataset_size),
("n_val", val_dataset_size),
):
n = getattr(self, n_name)
if isinstance(n, str) and "%" in n:
n_train_n_val.append(
(float(n.rstrip("%")) / 100) * dataset_size
) # convert to float first
elif isinstance(n, int):
n_train_n_val.append(n)
else:
raise ValueError(
f"Invalid value/type for {n_name}: {n} -- must be either int or str with %!"
)

floored_n_train_n_val = [int(n) for n in n_train_n_val]
for n, n_name in zip(floored_n_train_n_val, ["n_train", "n_val"]):
if n < 1:
raise ValueError(f"{n_name} must be at least 1! Got {n}.")

# if n_train and n_val were both set as percentages which summed to 100%, make sure that sum of
# floored values comes to 100% of dataset size (i.e. that flooring doesn't omit a frame)
if (
train_dataset_size == val_dataset_size
and isinstance(self.n_train, str)
and isinstance(self.n_val, str)
and np.isclose(
float(self.n_train.strip("%")) + float(self.n_val.strip("%")), 100
)
):
if (
sum(floored_n_train_n_val) != train_dataset_size
): # one frame was cut, add to larger of the
# two float values (i.e. round up the percentage which gave a >= x.5 float value)
floored_n_train_n_val[
np.argmax(n_train_n_val)
] += train_dataset_size - sum(floored_n_train_n_val)

return tuple(floored_n_train_n_val)

def set_dataset(
self,
dataset: AtomicDataset,
validation_dataset: Optional[AtomicDataset] = None,
) -> None:
"""Set the dataset(s) used by this trainer.
"""
Set the dataset(s) used by this trainer.
Training and validation datasets will be sampled from
them in accordance with the trainer's parameters.
@@ -1163,7 +1210,10 @@ def set_dataset(
if validation_dataset is None:
# Sample both from `dataset`:
total_n = len(dataset)
if (self.n_train + self.n_val) > total_n:
n_train, n_val = self._parse_n_train_n_val(
train_dataset_size=total_n, val_dataset_size=total_n
)
if (n_train + n_val) > total_n:
raise ValueError(
"too little data for training and validation. please reduce n_train and n_val"
)
@@ -1177,25 +1227,29 @@
f"splitting mode {self.train_val_split} not implemented"
)

self.train_idcs = idcs[: self.n_train]
self.val_idcs = idcs[self.n_train : self.n_train + self.n_val]
self.train_idcs = idcs[:n_train]
self.val_idcs = idcs[n_train : n_train + n_val]
else:
if self.n_train > len(dataset):
n_train, n_val = self._parse_n_train_n_val(
train_dataset_size=len(dataset),
val_dataset_size=len(validation_dataset),
)
if n_train > len(dataset):
raise ValueError("Not enough data in dataset for requested n_train")
if self.n_val > len(validation_dataset):
if n_val > len(validation_dataset):
raise ValueError(
"Not enough data in validation dataset for requested n_val"
)
if self.train_val_split == "random":
self.train_idcs = torch.randperm(
len(dataset), generator=self.dataset_rng
)[: self.n_train]
)[:n_train]
self.val_idcs = torch.randperm(
len(validation_dataset), generator=self.dataset_rng
)[: self.n_val]
)[:n_val]
elif self.train_val_split == "sequential":
self.train_idcs = torch.arange(self.n_train)
self.val_idcs = torch.arange(self.n_val)
self.train_idcs = torch.arange(n_train)
self.val_idcs = torch.arange(n_val)
else:
raise NotImplementedError(
f"splitting mode {self.train_val_split} not implemented"
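To make the `_parse_n_train_n_val` change above concrete, here is a standalone sketch of the same parsing and flooring logic with a worked example (the free function and its argument names are illustrative; in the diff this is a `Trainer` method reading `self.n_train`/`self.n_val`, and it additionally checks that both values are at least 1):

```python
from typing import Tuple, Union

import numpy as np


def parse_n_train_n_val(
    n_train: Union[int, str],
    n_val: Union[int, str],
    train_dataset_size: int,
    val_dataset_size: int,
) -> Tuple[int, int]:
    """Resolve n_train/n_val given either ints or percentage strings like "70%"."""
    raw = []
    for n, size in ((n_train, train_dataset_size), (n_val, val_dataset_size)):
        if isinstance(n, str) and "%" in n:
            raw.append(float(n.rstrip("%")) / 100 * size)
        elif isinstance(n, int):
            raw.append(n)
        else:
            raise ValueError(f"{n!r} must be an int or a percentage string")
    floored = [int(n) for n in raw]
    # If both were given as percentages summing to 100% of the same dataset, flooring
    # may drop one frame; hand it to the split with the larger fractional share.
    if (
        train_dataset_size == val_dataset_size
        and isinstance(n_train, str)
        and isinstance(n_val, str)
        and np.isclose(float(n_train.strip("%")) + float(n_val.strip("%")), 100)
        and sum(floored) != train_dataset_size
    ):
        floored[int(np.argmax(raw))] += train_dataset_size - sum(floored)
    return tuple(floored)


# Worked example: "70%"/"30%" of a 101-frame dataset floors to (70, 30) = 100 frames,
# so the leftover frame goes to the larger fractional share, giving (71, 30).
assert parse_n_train_n_val("70%", "30%", 101, 101) == (71, 30)
```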
1 change: 1 addition & 0 deletions nequip/utils/config.py
@@ -34,6 +34,7 @@
If a parameter is updated, the updated value will be formatted back to the same type.
"""

from typing import Set, Dict, Any, List

import inspect
1 change: 1 addition & 0 deletions nequip/utils/savenload.py
@@ -1,6 +1,7 @@
"""
utilities that involve file searching and operations (i.e. save/load)
"""

from typing import Union, List, Tuple, Optional, Callable
import sys
import logging
8 changes: 5 additions & 3 deletions tests/integration/test_evaluate.py
@@ -171,9 +171,11 @@ def runit(params: dict):
assert np.allclose(
err,
0.0,
atol=1e-8
if true_identity
else (1e-2 if metric.startswith("e") else 1e-4),
atol=(
1e-8
if true_identity
else (1e-2 if metric.startswith("e") else 1e-4)
),
), f"Metric `{metric}` wasn't zero!"
elif builder == ConstFactorModel:
# TODO: check comperable to naive numpy compute
2 changes: 1 addition & 1 deletion tests/unit/model/test_pair/test_zbl.py
@@ -81,7 +81,7 @@ def test_lammps_repro(self, config):
# $ lmp -in zbl_data.lmps
# $ python -c "import numpy as np; d = np.loadtxt('zbl.dat', skiprows=1); np.save('zbl.npy', d)"
refdata = np.load(Path(__file__).parent / "zbl.npy")
for (r, Zi, Zj, pe, fxi, fxj) in refdata:
for r, Zi, Zj, pe, fxi, fxj in refdata:
if r >= r_max:
continue
atoms.positions[1, 0] = r
