Skip to content

Commit

Permalink
removed statistics file as a requirement for training - close #21
Browse files Browse the repository at this point in the history
  • Loading branch information
pbenner committed Mar 2, 2025
1 parent 43ceeb7 commit 3188706
Show file tree
Hide file tree
Showing 14 changed files with 88 additions and 77 deletions.
27 changes: 12 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ Train a model using the prepared dataset and specify the MLIP wrapper:
equitrain -v \
--train-file data/train.h5 \
--valid-file data/valid.h5 \
--statistics-file data/statistics.json \
--output-dir result \
--model mace.model \
--model-wrapper 'mace' \
Expand All @@ -161,16 +160,15 @@ from equitrain.model_wrappers import MaceWrapper

def test_train_mace():
args = get_args_parser_train().parse_args()
args.train_file = 'data/train.h5'
args.valid_file = 'data/valid.h5'
args.statistics_file = 'data/statistics.json'
args.output_dir = 'test_train_mace'
args.epochs = 10
args.batch_size = 64
args.lr = 0.01
args.verbose = 1
args.tqdm = True
args.model = MaceWrapper(args, "mace.model")
args.train_file = 'data/train.h5'
args.valid_file = 'data/valid.h5'
args.output_dir = 'test_train_mace'
args.epochs = 10
args.batch_size = 64
args.lr = 0.01
args.verbose = 1
args.tqdm = True
args.model = MaceWrapper(args, "mace.model")

train(args)

Expand All @@ -193,10 +191,9 @@ from equitrain.model_wrappers import MaceWrapper

def test_mace_predict():
args = get_args_parser_predict().parse_args()
args.predict_file = 'data/valid.h5'
args.statistics_file = 'data/statistics.json'
args.batch_size = 64
args.model = MaceWrapper(args, "mace.model")
args.predict_file = 'data/valid.h5'
args.batch_size = 64
args.model = MaceWrapper(args, "mace.model")

energy_pred, forces_pred, stress_pred = predict(args)

Expand Down
6 changes: 0 additions & 6 deletions equitrain/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,6 @@ def add_common_file_args(parser: argparse.ArgumentParser) -> argparse.ArgumentPa
parser.add_argument('--train-file', help='Training data', type=str, default=None)
parser.add_argument('--valid-file', help='Validation data', type=str, default=None)
parser.add_argument('--test-file', help='Test data', type=str, default=None)
parser.add_argument(
'--statistics-file',
help='Statistics file in JSON format',
type=str,
default=None,
)
parser.add_argument(
'--output-dir', help='Output directory for h5 files', type=str, default=''
)
Expand Down
2 changes: 1 addition & 1 deletion equitrain/data/atomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

class AtomicNumberTable(list):
def __init__(self, zs: list):
super().__init__(sorted(list(zs)))
super().__init__(zs)

@classmethod
def from_zs(cls, zs: Iterable[int]):
Expand Down
2 changes: 1 addition & 1 deletion equitrain/data/format_lmdb/lmdb_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def main():
for data in lmdb_dataset:
atoms = convert_to_ase_object(data)
atomic_numbers.update(atoms.get_atomic_numbers())
z_table = AtomicNumberTable(list(atomic_numbers))
z_table = AtomicNumberTable(sorted(list(atomic_numbers)))

# Define output file paths
output_hdf5 = 'output_data.h5'
Expand Down
2 changes: 1 addition & 1 deletion equitrain/data/format_xyz/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ def update_atomic_energies(self, atoms, i):

@property
def atomic_numbers(self):
return AtomicNumberTable(self.z_set)
return AtomicNumberTable(sorted(list(self.z_set)))
60 changes: 24 additions & 36 deletions equitrain/data/loaders.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import torch
from accelerate import Accelerator

Expand Down Expand Up @@ -42,27 +44,16 @@ def dataloader_update_errors(


def get_dataloader(
data_file,
args,
statistics=None,
data_file: Path | str,
atomic_numbers: list[int],
r_max: float,
accelerator: Accelerator = None,
logger: FileLogger = None,
):
if data_file is None:
return None

if statistics is None:
statistics = Statistics.load(args.statistics_file)

if logger is not None:
logger.log(
1,
f'Using r_max={statistics.r_max} from statistics file `{args.statistics_file}`',
)

data_set = HDF5GraphDataset(
data_file, r_max=statistics.r_max, atomic_numbers=statistics.atomic_numbers
)
data_set = HDF5GraphDataset(data_file, r_max=r_max, atomic_numbers=atomic_numbers)

data_loader = DynamicGraphLoader(
dataset=data_set,
Expand All @@ -83,35 +74,32 @@ def get_dataloader(
return data_loader


def get_dataloaders(args, accelerator: Accelerator = None, logger: FileLogger = None):
statistics = Statistics.load(args.statistics_file)

if logger is not None:
logger.log(
1,
f'Using r_max={statistics.r_max} from statistics file `{args.statistics_file}`',
)

def get_dataloaders(
args,
atomic_numbers: list[int],
r_max: float,
accelerator: Accelerator = None,
):
train_loader = get_dataloader(
args.train_file,
args,
statistics=statistics,
accelerator=accelerator,
logger=logger,
args.train_file,
atomic_numbers,
r_max,
accelerator,
)
valid_loader = get_dataloader(
args.valid_file,
args,
statistics=statistics,
accelerator=accelerator,
logger=logger,
args.valid_file,
atomic_numbers,
r_max,
accelerator,
)
test_loader = get_dataloader(
args.test_file,
args,
statistics=statistics,
accelerator=accelerator,
logger=logger,
args.test_file,
atomic_numbers,
r_max,
accelerator,
)

return train_loader, valid_loader, test_loader
2 changes: 1 addition & 1 deletion equitrain/data/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def compute_atomic_numbers(
# Convert from int64 to int32, which is json serializable
z_set.update([int(z) for z in batch[0].get_atomic_numbers()])

return AtomicNumberTable(z_set)
return AtomicNumberTable(sorted(list(z_set)))


def compute_average_atomic_energies(
Expand Down
20 changes: 20 additions & 0 deletions equitrain/model_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch

from equitrain.data.atomic import AtomicNumberTable


class MaceWrapper(torch.nn.Module):
def __init__(self, args, model, optimize_atomic_energies=False):
Expand Down Expand Up @@ -30,6 +32,14 @@ def forward(self, *args):

return y_pred

@property
def atomic_numbers(self):
    """Atomic numbers of the wrapped model, packaged as an `AtomicNumberTable`.

    The values are read from `self.model.atomic_numbers` and converted to a
    plain Python list via `.tolist()` before the table is built.
    """
    zs = self.model.atomic_numbers.tolist()
    return AtomicNumberTable(zs)

@property
def r_max(self):
    """Cutoff radius of the wrapped model as a plain Python scalar.

    `self.model.r_max` is unwrapped with `.item()`; presumably it is a
    single-element tensor/buffer -- confirm against the MACE model class.
    """
    cutoff = self.model.r_max
    return cutoff.item()


class SevennetWrapper(torch.nn.Module):
def __init__(self, args, model):
Expand Down Expand Up @@ -93,3 +103,13 @@ def batch_voigt_to_tensor(cls, voigts):
tensors[:, 0, 2] = tensors[:, 2, 0] = voigts[:, 4] # σ_xz
tensors[:, 0, 1] = tensors[:, 1, 0] = voigts[:, 5] # σ_xy
return tensors

@property
def atomic_numbers(self):
    """Atomic numbers supported by the wrapped model, as an `AtomicNumberTable`.

    The table is built from the indices at which
    `self.model.z_to_onehot_tensor` differs from -1 (presumably the tensor is
    1-D and indexed by atomic number, with -1 marking unsupported elements --
    confirm against the SevenNet model).

    Returns:
        AtomicNumberTable: one entry per index whose value is not -1.
    """
    supported = torch.nonzero(self.model.z_to_onehot_tensor != -1)
    # `nonzero` yields shape (N, 1) for a 1-D input. Only the trailing
    # dimension is squeezed: an argument-less `squeeze()` would collapse a
    # single match (N == 1) to a 0-d tensor, making `tolist()` return a bare
    # int instead of a one-element list.
    return AtomicNumberTable(supported.squeeze(-1).tolist())

@property
def r_max(self):
    """Cutoff radius of the wrapped model, read from `self.model.cutoff`."""
    wrapped = self.model
    return wrapped.cutoff
8 changes: 4 additions & 4 deletions equitrain/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,12 @@ def _predict(args, device=None):
r_force = torch.empty((0, 3), device=device)
r_stress = torch.empty((0, 3, 3), device=device)

data_loader = get_dataloader(args.predict_file, args)

model = get_model(args)

data_loader = get_dataloader(
args, args.predict_file, model.atomic_numbers, model.r_max
)

for data_list in data_loader:
for data in data_list:
y_pred = model(data)
Expand All @@ -150,8 +152,6 @@ def predict(args):

if args.predict_file is None:
raise ValueError('--predict-file is a required argument')
if args.statistics_file is None:
raise ValueError('--statistics-file is a required argument')
if args.model is None:
raise ValueError('--model is a required argument')

Expand Down
10 changes: 4 additions & 6 deletions equitrain/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,14 @@ def _train_with_accelerator(args, accelerator: Accelerator):
)
logger.log(1, ArgsFormatter(args))

""" Network """
model = get_model(args, logger=logger)

""" Data Loader """
train_loader, val_loader, test_loader = get_dataloaders(
args, accelerator, logger=logger
args, model.atomic_numbers, model.r_max, accelerator
)

""" Network """
model = get_model(args, logger=logger)

""" Optimizer and LR Scheduler """
optimizer = create_optimizer(args, model)
lr_scheduler = create_scheduler(args, optimizer)
Expand Down Expand Up @@ -467,8 +467,6 @@ def train(args):
raise ArgumentError('--train-file is a required argument')
if args.valid_file is None:
raise ArgumentError('--valid-file is a required argument')
if args.statistics_file is None:
raise ArgumentError('--statistics-file is a required argument')
if args.output_dir is None:
raise ArgumentError('--output-dir is a required argument')
if args.model is None:
Expand Down
1 change: 0 additions & 1 deletion resources/training/mace-alex-mptraj.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def main():

args.train_file = data_dir / 'train.h5'
args.valid_file = data_dir / 'valid.h5'
args.statistics_file = data_dir / 'statistics.json'
args.output_dir = 'result'
args.model = 'mace-initial.model'
args.model_wrapper = 'mace'
Expand Down
1 change: 0 additions & 1 deletion tests/test_predict_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ def test_mace_predict():
args = get_args_parser_predict().parse_args()

args.predict_file = 'data/valid.h5'
args.statistics_file = 'data/statistics.json'
args.batch_size = 5
args.model = MaceWrapper(args)

Expand Down
1 change: 0 additions & 1 deletion tests/test_train_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ def test_train_mace():
args.train_file = 'data/train.h5'
args.valid_file = 'data/valid.h5'
args.test_file = 'data/train.h5'
args.statistics_file = 'data/statistics.json'
args.output_dir = 'test_train_mace'
args.model = MaceWrapper(args)

Expand Down
23 changes: 20 additions & 3 deletions tests/test_train_sevennet.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
from equitrain import get_args_parser_train, train
from equitrain.data import Statistics
from equitrain.utility_test import SevennetWrapper


def test_train_mace():
def test_sevennet_atomic_numbers():
    """Check that the SevenNet wrapper reports the same atomic numbers as the
    statistics file it was built from."""
    args = get_args_parser_train().parse_args()

    # Reference values come straight from the statistics file.
    statistics = Statistics.load('data/statistics.json')

    # The wrapper is constructed from the same statistics file.
    wrapper_kwargs = {
        'filename_config': 'test_train_sevennet.yaml',
        'filename_statistics': 'data/statistics.json',
    }
    model = SevennetWrapper(args, **wrapper_kwargs)

    reported = model.atomic_numbers
    expected = statistics.atomic_numbers
    assert reported == expected, 'atomic numbers do not match'


def test_train_sevennet():
args = get_args_parser_train().parse_args()

args.train_file = 'data/train.h5'
args.valid_file = 'data/valid.h5'
args.test_file = 'data/train.h5'
args.statistics_file = 'data/statistics.json'
args.output_dir = 'test_train_sevennet'
args.model = SevennetWrapper(
args,
Expand All @@ -27,4 +43,5 @@ def test_train_mace():


if __name__ == '__main__':
test_train_mace()
test_sevennet_atomic_numbers()
test_train_sevennet()

0 comments on commit 3188706

Please sign in to comment.