diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index 5f7c96cd..df800d8e 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -18,7 +18,7 @@ jobs:
       - name: Black Check
         uses: psf/black@stable
         with:
-          version: "22.3.0"
+          version: "24.4.2"
 
   flake8:
     runs-on: ubuntu-latest
@@ -29,7 +29,7 @@ jobs:
           python-version: '3.x'
       - name: Install flake8
         run: |
-          pip install flake8==7.0.0
+          pip install flake8==7.1.0
       - name: run flake8
         run: |
           flake8 . --count --show-source --statistics
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 0fb33150..b09cb076 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -31,6 +31,7 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           pip install setuptools wheel
+          if [ ${TORCH} = "1.13.1" ]; then pip install numpy==1.*; fi  # older torch versions fail with numpy 2
           pip install torch==${TORCH} -f https://download.pytorch.org/whl/cpu/torch_stable.html
           pip install h5py scikit-learn  # install packages that aren't required dependencies but that the tests do need
           pip install --upgrade-strategy only-if-needed .
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3a5f9bb9..4d4b9abb 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,11 +4,11 @@ fail_fast: true
 repos:
   - repo: https://github.com/psf/black
-    rev: 22.3.0
+    rev: 24.4.2
     hooks:
       - id: black
 
-  - repo: https://gitlab.com/pycqa/flake8
-    rev: 4.0.1
+  - repo: https://github.com/pycqa/flake8
+    rev: 7.1.0
     hooks:
       - id: flake8
diff --git a/CHANGELOG.md b/CHANGELOG.md
index c60ed185..e1c228d0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,8 +6,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 Most recent change on the bottom.
 
+
 ## Unreleased
+
+## [0.6.1] - 2024-7-9
+### Added
+- Add support for equivariance testing of arbitrary Cartesian tensor outputs
+- [Breaking] Use entry points for `nequip.extension`s (e.g. for field registration)
+- Alternate neighborlist support, enabled with the `NEQUIP_NL` environment variable, which can be set to `ase` (default), `matscipy`, or `vesin`
+- Allow `n_train` and `n_val` to be specified as percentages of datasets
+- Only attempt training restart if a `trainer.pth` file is present (prevents unnecessary crashes due to file-not-found errors in some cases)
+
+### Changed
+- [Breaking] `NEQUIP_MATSCIPY_NL` environment variable no longer supported
+
+### Fixed
+- Fixed `flake8` install location in `.pre-commit-config.yaml`
+
+
 ## [0.6.0] - 2024-5-10
 ### Added
 - add Tensorboard as logger option
diff --git a/configs/full.yaml b/configs/full.yaml
index d13ff041..765d90d8 100644
--- a/configs/full.yaml
+++ b/configs/full.yaml
@@ -181,6 +181,9 @@ save_ema_checkpoint_freq: -1
 # training
 n_train: 100 # number of training data
 n_val: 50 # number of validation data
+# alternatively, n_train and n_val can be set as percentages of the dataset size:
+# n_train: 70% # 70% of dataset
+# n_val: 30% # 30% of dataset (if validation_dataset not set), or 30% of validation_dataset (if set)
 learning_rate: 0.005 # learning rate, we found values between 0.01 and 0.005 to work best - this is often one of the most important hyperparameters to tune
 batch_size: 5 # batch size, we found it important to keep this small for most applications including forces (1-5); for energy-only training, higher batch sizes work better
 validation_batch_size: 10 # batch size for evaluating the model during validation. This does not affect the training results, but using the highest value possible (<=n_val) without running out of memory will speed up your training.
diff --git a/nequip/__init__.py b/nequip/__init__.py
index 3a8d6d5c..ce145b41 100644
--- a/nequip/__init__.py
+++ b/nequip/__init__.py
@@ -1 +1,38 @@
+import sys
+
 from ._version import __version__  # noqa: F401
+
+import packaging.version
+
+import torch
+import warnings
+
+# torch version checks
+torch_version = packaging.version.parse(torch.__version__)
+
+# only allow 1.11*, 1.13* or higher (no 1.12.*)
+assert (torch_version == packaging.version.parse("1.11")) or (
+    torch_version >= packaging.version.parse("1.13")
+), f"NequIP supports PyTorch 1.11.* or 1.13.* or later, but {torch_version} found"
+
+# warn if using 1.13* or 2.*
+if packaging.version.parse("1.13.0") <= torch_version:
+    warnings.warn(
+        f"!! PyTorch version {torch_version} found. Upstream issues in PyTorch versions 1.13.* and 2.* have been seen to cause unusual performance degradations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. The best-tested PyTorch version to use with CUDA devices is 1.11; if, while using other versions, you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue."
+    )
+
+
+# Load all installed nequip extension packages
+# This allows installed extensions to register themselves in
+# the nequip infrastructure with calls like `register_fields`
+
+# see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata
+if sys.version_info < (3, 10):
+    from importlib_metadata import entry_points
+else:
+    from importlib.metadata import entry_points
+
+_DISCOVERED_NEQUIP_EXTENSION = entry_points(group="nequip.extension")
+for ep in _DISCOVERED_NEQUIP_EXTENSION:
+    if ep.name == "init_always":
+        ep.load()
diff --git a/nequip/_version.py b/nequip/_version.py
index 8e22989a..6c1533d0 100644
--- a/nequip/_version.py
+++ b/nequip/_version.py
@@ -2,4 +2,4 @@
 # See Python packaging guide
 # https://packaging.python.org/guides/single-sourcing-package-version/
 
-__version__ = "0.6.0"
+__version__ = "0.6.1"
diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py
index 70c8fd2e..805d0cf5 100644
--- a/nequip/data/AtomicData.py
+++ b/nequip/data/AtomicData.py
@@ -10,7 +10,6 @@ import os
 
 import numpy as np
-import ase.neighborlist
 import ase
 from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator
 from ase.calculators.calculator import all_properties as ase_all_properties
@@ -18,6 +17,7 @@
 import torch
 import e3nn.o3
+from e3nn.io import CartesianTensor
 
 from . import AtomicDataDict
 from ._util import _TORCH_INTEGER_DTYPES
@@ -26,6 +26,7 @@
 # A type representing ASE-style periodic boundary conditions, which can be partial (the tuple case)
 PBC = Union[bool, Tuple[bool, bool, bool]]
 
+# === Key Registration ===
 
 _DEFAULT_LONG_FIELDS: Set[str] = {
     AtomicDataDict.EDGE_INDEX_KEY,
@@ -61,10 +62,15 @@
     AtomicDataDict.CELL_KEY,
     AtomicDataDict.BATCH_PTR_KEY,
 }
+_DEFAULT_CARTESIAN_TENSOR_FIELDS: Dict[str, str] = {
+    AtomicDataDict.STRESS_KEY: "ij=ji",
+    AtomicDataDict.VIRIAL_KEY: "ij=ji",
+}
 _NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS)
 _EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS)
 _GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS)
 _LONG_FIELDS: Set[str] = set(_DEFAULT_LONG_FIELDS)
+_CARTESIAN_TENSOR_FIELDS: Dict[str, str] = dict(_DEFAULT_CARTESIAN_TENSOR_FIELDS)
 
 
 def register_fields(
@@ -72,6 +78,7 @@ def register_fields(
     edge_fields: Sequence[str] = [],
     graph_fields: Sequence[str] = [],
     long_fields: Sequence[str] = [],
+    cartesian_tensor_fields: Dict[str, str] = {},
 ) -> None:
     r"""Register fields as being per-atom, per-edge, or per-frame.
 
@@ -83,18 +90,36 @@ def register_fields(
     edge_fields: set = set(edge_fields)
     graph_fields: set = set(graph_fields)
     long_fields: set = set(long_fields)
-    allfields = node_fields.union(edge_fields, graph_fields)
-    assert len(allfields) == len(node_fields) + len(edge_fields) + len(graph_fields)
+
+    # error checking: prevents registering fields as contradictory types
+    # potentially unregistered fields
+    assert len(node_fields.intersection(edge_fields)) == 0
+    assert len(node_fields.intersection(graph_fields)) == 0
+    assert len(edge_fields.intersection(graph_fields)) == 0
+    # already registered fields
+    assert len(_NODE_FIELDS.intersection(edge_fields)) == 0
+    assert len(_NODE_FIELDS.intersection(graph_fields)) == 0
+    assert len(_EDGE_FIELDS.intersection(node_fields)) == 0
+    assert len(_EDGE_FIELDS.intersection(graph_fields)) == 0
+    assert len(_GRAPH_FIELDS.intersection(edge_fields)) == 0
+    assert len(_GRAPH_FIELDS.intersection(node_fields)) == 0
+
+    # check that Cartesian tensor fields to add are rank-2 (higher ranks not supported)
+    for cart_tensor_key in cartesian_tensor_fields:
+        cart_tensor_rank = len(
+            CartesianTensor(cartesian_tensor_fields[cart_tensor_key]).indices
+        )
+        if cart_tensor_rank != 2:
+            raise NotImplementedError(
+                f"Only rank-2 tensor data processing is supported, but {cart_tensor_key} is rank {cart_tensor_rank}. Consider raising a GitHub issue if higher-rank tensor data processing is desired."
+            )
+
+    # update fields
     _NODE_FIELDS.update(node_fields)
     _EDGE_FIELDS.update(edge_fields)
     _GRAPH_FIELDS.update(graph_fields)
     _LONG_FIELDS.update(long_fields)
-    if len(set.union(_NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS)) < (
-        len(_NODE_FIELDS) + len(_EDGE_FIELDS) + len(_GRAPH_FIELDS)
-    ):
-        raise ValueError(
-            "At least one key was registered as more than one of node, edge, or graph!"
- ) + _CARTESIAN_TENSOR_FIELDS.update(cartesian_tensor_fields) def deregister_fields(*fields: Sequence[str]) -> None: @@ -109,9 +134,16 @@ def deregister_fields(*fields: Sequence[str]) -> None: assert f not in _DEFAULT_NODE_FIELDS, "Cannot deregister built-in field" assert f not in _DEFAULT_EDGE_FIELDS, "Cannot deregister built-in field" assert f not in _DEFAULT_GRAPH_FIELDS, "Cannot deregister built-in field" + assert f not in _DEFAULT_LONG_FIELDS, "Cannot deregister built-in field" + assert ( + f not in _DEFAULT_CARTESIAN_TENSOR_FIELDS + ), "Cannot deregister built-in field" + _NODE_FIELDS.discard(f) _EDGE_FIELDS.discard(f) _GRAPH_FIELDS.discard(f) + _LONG_FIELDS.discard(f) + _CARTESIAN_TENSOR_FIELDS.pop(f, None) def _register_field_prefix(prefix: str) -> None: @@ -125,6 +157,9 @@ def _register_field_prefix(prefix: str) -> None: ) +# === AtomicData === + + def _process_dict(kwargs, ignore_fields=[]): """Convert a dict of data into correct dtypes/shapes according to key""" # Deal with _some_ dtype issues @@ -449,17 +484,40 @@ def from_ase( cell = kwargs.pop("cell", atoms.get_cell()) pbc = kwargs.pop("pbc", atoms.pbc) - # handle ASE-style 6 element Voigt order stress - for key in (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY): - if key in add_fields: - if add_fields[key].shape == (3, 3): - # it's already 3x3, do nothing else - pass - elif add_fields[key].shape == (6,): - # it's Voigt order - add_fields[key] = voigt_6_to_full_3x3_stress(add_fields[key]) + # IMPORTANT: the following reshape logic only applies to rank-2 Cartesian tensor fields + for key in add_fields: + if key in _CARTESIAN_TENSOR_FIELDS: + # enforce (3, 3) shape for graph fields, e.g. stress, virial + if key in _GRAPH_FIELDS: + # handle ASE-style 6 element Voigt order stress + if key in (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY): + if add_fields[key].shape == (6,): + add_fields[key] = voigt_6_to_full_3x3_stress( + add_fields[key] + ) + if add_fields[key].shape == (3, 3): + # it's already 3x3, do nothing else + pass + elif add_fields[key].shape == (9,): + add_fields[key] = add_fields[key].reshape((3, 3)) + else: + raise RuntimeError( + f"bad shape for {key} registered as a Cartesian tensor graph field---please note that only rank-2 Cartesian tensors are currently supported" + ) + # enforce (N_atom, 3, 3) shape for node fields, e.g. 
Born effective charges + elif key in _NODE_FIELDS: + if add_fields[key].shape[1:] == (3, 3): + pass + elif add_fields[key].shape[1:] == (9,): + add_fields[key] = add_fields[key].reshape((-1, 3, 3)) + else: + raise RuntimeError( + f"bad shape for {key} registered as a Cartesian tensor node field---please note that only rank-2 Cartesian tensors are currently supported" + ) else: - raise RuntimeError(f"bad shape for {key}") + raise RuntimeError( + f"{key} registered as a Cartesian tensor field was not registered as either a graph or node field" + ) return cls.from_points( pos=atoms.positions, @@ -705,12 +763,21 @@ def without_nodes(self, which_nodes): assert _ERROR_ON_NO_EDGES in ("true", "false") _ERROR_ON_NO_EDGES = _ERROR_ON_NO_EDGES == "true" -_NEQUIP_MATSCIPY_NL: Final[bool] = os.environ.get("NEQUIP_MATSCIPY_NL", "false").lower() -assert _NEQUIP_MATSCIPY_NL in ("true", "false") -_NEQUIP_MATSCIPY_NL = _NEQUIP_MATSCIPY_NL == "true" +# use "ase" as default +# TODO: eventually, choose fastest as default +# NOTE: +# - vesin and matscipy do not support self-interaction +# - vesin does not allow for mixed pbcs +_NEQUIP_NL: Final[str] = os.environ.get("NEQUIP_NL", "ase").lower() -if _NEQUIP_MATSCIPY_NL: +if _NEQUIP_NL == "vesin": + from vesin import NeighborList as vesin_nl +elif _NEQUIP_NL == "matscipy": import matscipy.neighbours +elif _NEQUIP_NL == "ase": + import ase.neighborlist +else: + raise NotImplementedError(f"Unknown neighborlist NEQUIP_NL = {_NEQUIP_NL}") def neighbor_list_and_relative_vec( @@ -790,7 +857,24 @@ def neighbor_list_and_relative_vec( # ASE dependent part temp_cell = ase.geometry.complete_cell(temp_cell) - if _NEQUIP_MATSCIPY_NL: + if _NEQUIP_NL == "vesin": + assert strict_self_interaction and not self_interaction + # use same mixed pbc logic as + # https://github.com/Luthaf/vesin/blob/main/python/vesin/src/vesin/_ase.py + if pbc[0] and pbc[1] and pbc[2]: + periodic = True + elif not pbc[0] and not pbc[1] and not pbc[2]: + periodic = False + else: + raise ValueError( + "different periodic boundary conditions on different axes are not supported by vesin neighborlist, use ASE or matscipy" + ) + + first_idex, second_idex, shifts = vesin_nl( + cutoff=float(r_max), full_list=True + ).compute(points=temp_pos, box=temp_cell, periodic=periodic, quantities="ijS") + + elif _NEQUIP_NL == "matscipy": assert strict_self_interaction and not self_interaction first_idex, second_idex, shifts = matscipy.neighbours.neighbour_list( "ijS", @@ -799,7 +883,7 @@ def neighbor_list_and_relative_vec( positions=temp_pos, cutoff=float(r_max), ) - else: + elif _NEQUIP_NL == "ase": first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list( "ijS", pbc, diff --git a/nequip/data/AtomicDataDict.py b/nequip/data/AtomicDataDict.py index f7713e6f..ba8c75d1 100644 --- a/nequip/data/AtomicDataDict.py +++ b/nequip/data/AtomicDataDict.py @@ -5,6 +5,7 @@ Authors: Albert Musaelian """ + from typing import Dict, Any import torch @@ -67,7 +68,10 @@ def with_edge_vectors(data: Type, with_lengths: bool = True) -> Type: # (2) works on a Batch constructed from AtomicData pos = data[_keys.POSITIONS_KEY] edge_index = data[_keys.EDGE_INDEX_KEY] - edge_vec = pos[edge_index[1]] - pos[edge_index[0]] + # edge_vec = pos[edge_index[1]] - pos[edge_index[0]] + edge_vec = torch.index_select(pos, 0, edge_index[1]) - torch.index_select( + pos, 0, edge_index[0] + ) if _keys.CELL_KEY in data: # ^ note that to save time we don't check that the edge_cell_shifts are trivial if no cell is provided; we just assume they are 
either not present or all zero. # -1 gives a batch dim no matter what diff --git a/nequip/data/__init__.py b/nequip/data/__init__.py index 02c41d55..5cbbc853 100644 --- a/nequip/data/__init__.py +++ b/nequip/data/__init__.py @@ -8,6 +8,7 @@ _EDGE_FIELDS, _GRAPH_FIELDS, _LONG_FIELDS, + _CARTESIAN_TENSOR_FIELDS, ) from ._dataset import ( AtomicDataset, @@ -39,5 +40,6 @@ _EDGE_FIELDS, _GRAPH_FIELDS, _LONG_FIELDS, + _CARTESIAN_TENSOR_FIELDS, EMTTestDataset, ] diff --git a/nequip/data/_build.py b/nequip/data/_build.py index 35b59dba..80c46a94 100644 --- a/nequip/data/_build.py +++ b/nequip/data/_build.py @@ -3,7 +3,7 @@ from nequip import data from nequip.data.transforms import TypeMapper -from nequip.data import AtomicDataset, register_fields +from nequip.data import AtomicDataset from nequip.utils import instantiate, get_w_prefix @@ -71,10 +71,6 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset: # Build a TypeMapper from the config type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config) - # Register fields: - # This might reregister fields, but that's OK: - instantiate(register_fields, all_args=config) - instance, _ = instantiate( class_name, prefix=prefix, diff --git a/nequip/data/_dataset/_ase_dataset.py b/nequip/data/_dataset/_ase_dataset.py index 3246d791..633b5e48 100644 --- a/nequip/data/_dataset/_ase_dataset.py +++ b/nequip/data/_dataset/_ase_dataset.py @@ -51,10 +51,12 @@ def _ase_dataset_reader( datas.append( ( global_index, - AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs) - if global_index in include_frames - # in-memory dataset will ignore this later, but needed for indexing to work out - else None, + ( + AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs) + if global_index in include_frames + # in-memory dataset will ignore this later, but needed for indexing to work out + else None + ), ) ) # Save to a tempfile--- diff --git a/nequip/data/_dataset/_base_datasets.py b/nequip/data/_dataset/_base_datasets.py index bda86734..933a87a6 100644 --- a/nequip/data/_dataset/_base_datasets.py +++ b/nequip/data/_dataset/_base_datasets.py @@ -416,7 +416,7 @@ def statistics( if field not in selectors: # this means field is not selected and so not available raise RuntimeError( - f"Only per-node and per-graph fields can have statistics computed; `{field}` has not been registered as either. If it is per-node or per-graph, please register it as such using `nequip.data.register_fields`" + f"Only per-node and per-graph fields can have statistics computed; `{field}` has not been registered as either. If it is per-node or per-graph, please register it as such" ) arr = data_transformed[field] if field in _NODE_FIELDS: diff --git a/nequip/data/_keys.py b/nequip/data/_keys.py index edd04cbe..93f926d3 100644 --- a/nequip/data/_keys.py +++ b/nequip/data/_keys.py @@ -2,6 +2,7 @@ This is a seperate module to compensate for a TorchScript bug that can only recognize constants when they are accessed as attributes of an imported module. """ + import sys from typing import List diff --git a/nequip/data/dataloader.py b/nequip/data/dataloader.py index ea9c7fc9..9b95cd66 100644 --- a/nequip/data/dataloader.py +++ b/nequip/data/dataloader.py @@ -84,6 +84,7 @@ class PartialSampler(Sampler[int]): If `None`, defaults to `len(data_source)`. generator (Generator): Generator used in sampling. 
""" + data_source: Dataset num_samples_per_epoch: int shuffle: bool diff --git a/nequip/nn/_convnetlayer.py b/nequip/nn/_convnetlayer.py index 9e5437a8..6d339cab 100644 --- a/nequip/nn/_convnetlayer.py +++ b/nequip/nn/_convnetlayer.py @@ -149,9 +149,9 @@ def __init__( # updated with whatever the convolution outputs (which is a full graph module) self.irreps_out.update(self.conv.irreps_out) # but with the features updated by the nonlinearity - self.irreps_out[ - AtomicDataDict.NODE_FEATURES_KEY - ] = self.equivariant_nonlin.irreps_out + self.irreps_out[AtomicDataDict.NODE_FEATURES_KEY] = ( + self.equivariant_nonlin.irreps_out + ) def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: # save old features for resnet diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index ee0ce6f9..eb04d78a 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -20,6 +20,7 @@ class GradientOutput(GraphModuleMixin, torch.nn.Module): out_field: the field in which to return the computed gradients. Defaults to ``f"d({of})/d({wrt})"`` for each field in ``wrt``. sign: either 1 or -1; the returned gradient is multiplied by this. """ + sign: float _negate: bool skip: bool @@ -119,6 +120,7 @@ class PartialForceOutput(GraphModuleMixin, torch.nn.Module): vectorize: the vectorize option to ``torch.autograd.functional.jacobian``, false by default since it doesn't work well. """ + vectorize: bool def __init__( @@ -183,6 +185,7 @@ class StressOutput(GraphModuleMixin, torch.nn.Module): func: the energy model to wrap do_forces: whether to compute forces as well """ + do_forces: bool def __init__( diff --git a/nequip/nn/_interaction_block.py b/nequip/nn/_interaction_block.py index f3164709..a9dcecd7 100644 --- a/nequip/nn/_interaction_block.py +++ b/nequip/nn/_interaction_block.py @@ -1,4 +1,5 @@ """ Interaction Block """ + from typing import Optional, Dict, Callable import torch diff --git a/nequip/scripts/deploy.py b/nequip/scripts/deploy.py index a0772df9..24d36a52 100644 --- a/nequip/scripts/deploy.py +++ b/nequip/scripts/deploy.py @@ -71,6 +71,8 @@ def _set_deploy_metadata(key: str, value) -> None: global _current_metadata if _current_metadata is None: pass # not deploying right now + elif key not in _ALL_METADATA_KEYS: + raise KeyError(f"{key} is not a registered model deployment metadata key") elif key in _current_metadata: raise RuntimeError(f"{key} already set in the deployment metadata") else: @@ -108,10 +110,23 @@ def load_deployed_model( f"{model_path} does not seem to be a deployed NequIP model file. Did you forget to deploy it using `nequip-deploy`? \n\n(Underlying error: {e})" ) # Confirm nequip made it - if metadata[NEQUIP_VERSION_KEY] == "": - raise ValueError( - f"{model_path} does not seem to be a deployed NequIP model file" - ) + if len(metadata[NEQUIP_VERSION_KEY]) == 0: + if len(metadata[JIT_BAILOUT_KEY]) != 0: + # In versions <0.6.0, there may have been a bug leading to empty "*_version" + # metadata keys. We can be pretty confident this is a NequIP model from + # those versions, though, if it stored "_jit_bailout_depth" + # https://github.com/mir-group/nequip/commit/2f43aa84542df733bbe38cb9d6cca176b0e98054 + # Likely addresses https://github.com/mir-group/nequip/issues/431 + warnings.warn( + f"{model_path} appears to be from a older (0.5.* or earlier) version of `nequip` " + "that pre-dates a variety of breaking changes. 
Please carefully check the " + "correctness of your results for unexpected behaviour, and consider re-deploying " + "your model using this current `nequip` installation." + ) + else: + raise ValueError( + f"{model_path} does not seem to be a deployed NequIP model file" + ) # Confirm its TorchScript assert isinstance(model, torch.jit.ScriptModule) # Make sure we're in eval mode @@ -127,11 +142,14 @@ def load_deployed_model( if metadata[DEFAULT_DTYPE_KEY] == "": # Default and model go together assert metadata[MODEL_DTYPE_KEY] == "" - # If there isn't a dtype, it should be older than 0.6.0: - assert packaging.version.parse( - metadata[NEQUIP_VERSION_KEY] - ) < packaging.version.parse("0.6.0") - # i.e. no value due to L85 above + # If there isn't a dtype, it should be older than 0.6.0---but + # this may not be reflected in the version fields (see above check) + # So we only check if it is available: + if len(metadata[NEQUIP_VERSION_KEY]) > 0: + assert packaging.version.parse( + metadata[NEQUIP_VERSION_KEY] + ) < packaging.version.parse("0.6.0") + # The old pre-0.6.0 defaults: metadata[DEFAULT_DTYPE_KEY] = "float32" metadata[MODEL_DTYPE_KEY] = "float32" diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py index 20382eef..b40c3a8a 100644 --- a/nequip/scripts/evaluate.py +++ b/nequip/scripts/evaluate.py @@ -383,9 +383,11 @@ def main(args=None, running_as_script: bool = True): if do_metrics: display_bar = context_stack.enter_context( tqdm( - bar_format="" - if prog.disable # prog.ncols doesn't exist if disabled - else ("{desc:." + str(prog.ncols) + "}"), + bar_format=( + "" + if prog.disable # prog.ncols doesn't exist if disabled + else ("{desc:." + str(prog.ncols) + "}") + ), disable=None, ) ) diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 3d10049b..e83bd299 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -1,4 +1,5 @@ """ Train a network.""" + import logging import argparse import warnings @@ -7,7 +8,8 @@ # Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance. import numpy as np # noqa: F401 -from os.path import isdir +from os.path import exists, isdir +from shutil import rmtree from pathlib import Path import torch @@ -71,12 +73,21 @@ def main(args=None, running_as_script: bool = True): if running_as_script: set_up_script_logger(config.get("log", None), config.verbose) - found_restart_file = isdir(f"{config.root}/{config.run_name}") + found_restart_file = exists(f"{config.root}/{config.run_name}/trainer.pth") if found_restart_file and not config.append: raise RuntimeError( f"Training instance exists at {config.root}/{config.run_name}; " "either set append to True or use a different root or runname" ) + elif not found_restart_file and isdir(f"{config.root}/{config.run_name}"): + # output directory exists but no ``trainer.pth`` file, suggesting previous run crash during + # first training epoch (usually due to memory): + warnings.warn( + f"Previous run folder at {config.root}/{config.run_name} exists, but a saved model " + f"(trainer.pth file) was not found. This folder will be cleared and a fresh training run will " + f"be started." 
+ ) + rmtree(f"{config.root}/{config.run_name}") # for fresh new train if not found_restart_file: diff --git a/nequip/train/_loss.py b/nequip/train/_loss.py index 6442c0d4..144d348d 100644 --- a/nequip/train/_loss.py +++ b/nequip/train/_loss.py @@ -83,7 +83,8 @@ def __call__( # zero the nan entries has_nan = self.ignore_nan and torch.isnan(ref.sum()) N = torch.bincount(ref_dict[AtomicDataDict.BATCH_KEY]) - N = N.reshape((-1, 1)) + # as many dimensions of size 1 as there are non-batch dimensions in the data + N = N.reshape((-1,) + (1,) * (pred.ndim - 1)) if has_nan: not_nan = (ref == ref).int() loss = self.func(pred, torch.nan_to_num(ref, nan=0.0)) * not_nan / N diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index bdfb4f17..2c257785 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -7,6 +7,7 @@ make an interface with ray """ + import sys import inspect import logging @@ -107,7 +108,7 @@ class Trainer: - "trainer_save.pth": all the training information. The file used for loading and restart For restart run, the default set up is to not append to the original folders and files. - The Output class will automatically build a folder call root/run_name + The Output class will automatically build a folder called ``root/run_name`` If append mode is on, the log file will be appended and the best model and last model will be overwritten. More examples can be found in tests/train/test_trainer.py @@ -157,9 +158,9 @@ class Trainer: batch_size (int): size of each batch validation_batch_size (int): batch size for evaluating the model for validation shuffle (bool): parameters for dataloader - n_train (int): # of frames for training + n_train (int, str): # of frames for training (as int, or as a percentage string) n_train_per_epoch (optional int): how many frames from `n_train` to use each epoch; see `PartialSampler`. When `None`, all `n_train` frames will be used each epoch. - n_val (int): # of frames for validation + n_val (int), str: # of frames for validation (as int, or as a percentage string) exclude_keys (list): fields from dataset to ignore. 
dataloader_num_workers (int): `num_workers` for the `DataLoader`s train_idcs (optional, list): list of frames to use for training @@ -250,9 +251,9 @@ def __init__( batch_size: int = 5, validation_batch_size: int = 5, shuffle: bool = True, - n_train: Optional[int] = None, + n_train: Optional[Union[int, str]] = None, n_train_per_epoch: Optional[int] = None, - n_val: Optional[int] = None, + n_val: Optional[Union[int, str]] = None, dataloader_num_workers: int = 0, train_idcs: Optional[list] = None, val_idcs: Optional[list] = None, @@ -754,7 +755,6 @@ def init_metrics(self): ) def train(self): - """Training""" if getattr(self, "dl_train", None) is None: raise RuntimeError("You must call `set_dataset()` before calling `train()`") @@ -1144,12 +1144,59 @@ def __del__(self): for i in range(len(logger.handlers)): logger.handlers.pop() + def _parse_n_train_n_val( + self, train_dataset_size: int, val_dataset_size: int + ) -> Tuple[int, int]: + # parse n_train and n_val (can be ints or str with percentage): + n_train_n_val = [] + for n_name, dataset_size in ( + ("n_train", train_dataset_size), + ("n_val", val_dataset_size), + ): + n = getattr(self, n_name) + if isinstance(n, str) and "%" in n: + n_train_n_val.append( + (float(n.rstrip("%")) / 100) * dataset_size + ) # convert to float first + elif isinstance(n, int): + n_train_n_val.append(n) + else: + raise ValueError( + f"Invalid value/type for {n_name}: {n} -- must be either int or str with %!" + ) + + floored_n_train_n_val = [int(n) for n in n_train_n_val] + for n, n_name in zip(floored_n_train_n_val, ["n_train", "n_val"]): + if n < 1: + raise ValueError(f"{n_name} must be at least 1! Got {n}.") + + # if n_train and n_val were both set as percentages which summed to 100%, make sure that sum of + # floored values comes to 100% of dataset size (i.e. that flooring doesn't omit a frame) + if ( + train_dataset_size == val_dataset_size + and isinstance(self.n_train, str) + and isinstance(self.n_val, str) + and np.isclose( + float(self.n_train.strip("%")) + float(self.n_val.strip("%")), 100 + ) + ): + if ( + sum(floored_n_train_n_val) != train_dataset_size + ): # one frame was cut, add to larger of the + # two float values (i.e. round up the percentage which gave a >= x.5 float value) + floored_n_train_n_val[ + np.argmax(n_train_n_val) + ] += train_dataset_size - sum(floored_n_train_n_val) + + return tuple(floored_n_train_n_val) + def set_dataset( self, dataset: AtomicDataset, validation_dataset: Optional[AtomicDataset] = None, ) -> None: - """Set the dataset(s) used by this trainer. + """ + Set the dataset(s) used by this trainer. Training and validation datasets will be sampled from them in accordance with the trainer's parameters. @@ -1163,7 +1210,10 @@ def set_dataset( if validation_dataset is None: # Sample both from `dataset`: total_n = len(dataset) - if (self.n_train + self.n_val) > total_n: + n_train, n_val = self._parse_n_train_n_val( + train_dataset_size=total_n, val_dataset_size=total_n + ) + if (n_train + n_val) > total_n: raise ValueError( "too little data for training and validation. 
please reduce n_train and n_val" ) @@ -1177,25 +1227,29 @@ def set_dataset( f"splitting mode {self.train_val_split} not implemented" ) - self.train_idcs = idcs[: self.n_train] - self.val_idcs = idcs[self.n_train : self.n_train + self.n_val] + self.train_idcs = idcs[:n_train] + self.val_idcs = idcs[n_train : n_train + n_val] else: - if self.n_train > len(dataset): + n_train, n_val = self._parse_n_train_n_val( + train_dataset_size=len(dataset), + val_dataset_size=len(validation_dataset), + ) + if n_train > len(dataset): raise ValueError("Not enough data in dataset for requested n_train") - if self.n_val > len(validation_dataset): + if n_val > len(validation_dataset): raise ValueError( "Not enough data in validation dataset for requested n_val" ) if self.train_val_split == "random": self.train_idcs = torch.randperm( len(dataset), generator=self.dataset_rng - )[: self.n_train] + )[:n_train] self.val_idcs = torch.randperm( len(validation_dataset), generator=self.dataset_rng - )[: self.n_val] + )[:n_val] elif self.train_val_split == "sequential": - self.train_idcs = torch.arange(self.n_train) - self.val_idcs = torch.arange(self.n_val) + self.train_idcs = torch.arange(n_train) + self.val_idcs = torch.arange(n_val) else: raise NotImplementedError( f"splitting mode {self.train_val_split} not implemented" diff --git a/nequip/utils/_global_options.py b/nequip/utils/_global_options.py index bc5bc2d9..3a08e55e 100644 --- a/nequip/utils/_global_options.py +++ b/nequip/utils/_global_options.py @@ -7,9 +7,7 @@ import e3nn import e3nn.util.jit -from nequip.data import register_fields from .misc import dtype_from_name -from .auto_init import instantiate from .test import set_irreps_debug from .config import Config @@ -53,12 +51,6 @@ def _set_global_options(config, warn_on_override: bool = False) -> None: # Temporary warning due to unresolved upstream issue torch_version = version.parse(torch.__version__) - if torch_version < version.parse("1.11"): - warnings.warn("We currently recommend the use of PyTorch 1.11") - elif torch_version > version.parse("1.11"): - warnings.warn( - "!! Upstream issues in PyTorch versions >1.11 have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. At present we *strongly* recommend the use of PyTorch 1.11 if using CUDA devices; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue." - ) if torch_version >= version.parse("1.11"): # PyTorch >= 1.11 @@ -122,6 +114,4 @@ def _set_global_options(config, warn_on_override: bool = False) -> None: e3nn.set_optimization_defaults(**config.get("e3nn_optimization_defaults", {})) - # Register fields: - instantiate(register_fields, all_args=config) return diff --git a/nequip/utils/config.py b/nequip/utils/config.py index ca79f576..c160a135 100644 --- a/nequip/utils/config.py +++ b/nequip/utils/config.py @@ -34,6 +34,7 @@ If a parameter is updated, the updated value will be formatted back to the same type. """ + from typing import Set, Dict, Any, List import inspect diff --git a/nequip/utils/savenload.py b/nequip/utils/savenload.py index 53b09fcf..ffe60b19 100644 --- a/nequip/utils/savenload.py +++ b/nequip/utils/savenload.py @@ -1,6 +1,7 @@ """ utilities that involve file searching and operations (i.e. 
save/load) """ + from typing import Union, List, Tuple, Optional, Callable import sys import logging diff --git a/nequip/utils/test.py b/nequip/utils/test.py index 7c0bde3f..4c7450b5 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -2,7 +2,7 @@ import torch from e3nn import o3 -from e3nn.util.test import equivariance_error, FLOAT_TOLERANCE +from e3nn.util.test import equivariance_error from nequip.nn import GraphModuleMixin, GraphModel from nequip.data import ( @@ -10,8 +10,19 @@ AtomicDataDict, _NODE_FIELDS, _EDGE_FIELDS, + _CARTESIAN_TENSOR_FIELDS, ) - +from nequip.utils.misc import dtype_from_name + +# The default float tolerance +FLOAT_TOLERANCE = { + t: torch.as_tensor(v, dtype=dtype_from_name(t)) + for t, v in {"float32": 1e-3, "float64": 1e-10}.items() +} +# Allow lookup by name or dtype object: +for t, v in list(FLOAT_TOLERANCE.items()): + FLOAT_TOLERANCE[dtype_from_name(t)] = v +del t, v # This has to be somewhat large because of float32 sum reductions over many edges/atoms PERMUTATION_FLOAT_TOLERANCE = {torch.float32: 1e-4, torch.float64: 1e-10} @@ -45,9 +56,11 @@ def assert_permutation_equivariant( if tolerance is None: atol = PERMUTATION_FLOAT_TOLERANCE[ - func.model_dtype - if isinstance(func, GraphModel) - else torch.get_default_dtype() + ( + func.model_dtype + if isinstance(func, GraphModel) + else torch.get_default_dtype() + ) ] else: atol = tolerance @@ -209,17 +222,26 @@ def assert_AtomicData_equivariant( # must be this to actually rotate it when flattened irps[AtomicDataDict.CELL_KEY] = "3x1o" - stress_keys = (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY) - for k in stress_keys: + cartesian_keys = _CARTESIAN_TENSOR_FIELDS.keys() + for k in ( + AtomicDataDict.STRESS_KEY, + AtomicDataDict.VIRIAL_KEY, + ): # TODO should this be cartesian_keys? 
irreps_in.pop(k, None) - if any(k in irreps_out for k in stress_keys): + if any(k in irreps_out for k in cartesian_keys): from e3nn.io import CartesianTensor - stress_cart_tensor = CartesianTensor("ij=ji") # stress is symmetric - stress_rtp = stress_cart_tensor.reduced_tensor_products().to(device, dtype) - # symmetric 3x3 cartesian tensor as irreps - for k in stress_keys: - irreps_out[k] = stress_cart_tensor + cartesian_tensor = { + k: CartesianTensor(_CARTESIAN_TENSOR_FIELDS[k]) + for k in cartesian_keys + if k in irreps_out + } + cartesian_rtp = { + k: ct.reduced_tensor_products().to(device, dtype) + for k, ct in cartesian_tensor.items() + } + for k, ct in cartesian_tensor.items(): + irreps_out[k] = ct def wrapper(*args): arg_dict = {k: v for k, v in zip(irreps_in, args)} @@ -238,12 +260,12 @@ def wrapper(*args): val = output[key] assert val.shape[-2:] == (3, 3) output[key] = val.reshape(val.shape[:-2] + (9,)) - # stress is also a special case, + # cartesian tensors like stress are also a special case, # we need it to be decomposed into irreps for equivar testing - for k in stress_keys: + for k in cartesian_keys: if k in output: - output[k] = stress_cart_tensor.from_cartesian( - output[k], rtp=stress_rtp.to(output[k].dtype) + output[k] = cartesian_tensor[k].from_cartesian( + output[k], rtp=cartesian_rtp[k].to(output[k].dtype) ) return [output[k] for k in irreps_out] diff --git a/nequip/utils/unittests/conftest.py b/nequip/utils/unittests/conftest.py index a2dc103d..aa716d07 100644 --- a/nequip/utils/unittests/conftest.py +++ b/nequip/utils/unittests/conftest.py @@ -12,12 +12,11 @@ import torch -from nequip.utils.test import set_irreps_debug +from nequip.utils.test import set_irreps_debug, FLOAT_TOLERANCE from nequip.data import AtomicData, ASEDataset from nequip.data.transforms import TypeMapper from nequip.utils.torch_geometric import Batch from nequip.utils._global_options import _set_global_options -from nequip.utils.misc import dtype_from_name # Sometimes we run parallel using pytest-xdist, and want to be able to use # as many GPUs as are available @@ -42,12 +41,6 @@ # Test parallelization, but don't waste time spawning tons of workers if lots of cores available os.environ["NEQUIP_NUM_TASKS"] = "2" -# The default float tolerance -FLOAT_TOLERANCE = { - t: torch.as_tensor(v, dtype=dtype_from_name(t)) - for t, v in {"float32": 1e-3, "float64": 1e-10}.items() -} - @pytest.fixture(scope="session", autouse=True, params=["float32", "float64"]) def float_tolerance(request): diff --git a/nequip/utils/unittests/model_tests.py b/nequip/utils/unittests/model_tests.py index 37e9dcb6..b9d6790b 100644 --- a/nequip/utils/unittests/model_tests.py +++ b/nequip/utils/unittests/model_tests.py @@ -228,7 +228,13 @@ def test_equivariance(self, model, atomic_batch, device): instance, out_fields = model instance = instance.to(device=device) atomic_batch = atomic_batch.to(device=device) - assert_AtomicData_equivariant(func=instance, data_in=atomic_batch) + assert_AtomicData_equivariant( + func=instance, + data_in=atomic_batch, + e3_tolerance={torch.float32: 1e-3, torch.float64: 1e-8}[ + torch.get_default_dtype() + ], + ) def test_embedding_cutoff(self, model, config, device): instance, out_fields = model @@ -449,10 +455,12 @@ def test_partial_forces(self, config, atomic_batch, device, strict_locality): assert torch.allclose( output[k], output_partial[k], - atol=1e-8 - if k == AtomicDataDict.TOTAL_ENERGY_KEY - and torch.get_default_dtype() == torch.float64 - else 1e-5, + atol=( + 1e-8 + if k == 
AtomicDataDict.TOTAL_ENERGY_KEY + and torch.get_default_dtype() == torch.float64 + else 1e-5 + ), ) else: assert torch.equal(output[k], output_partial[k]) diff --git a/setup.py b/setup.py index 6ca9e3cf..af851a0e 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ "contextlib2;python_version<'3.7'", # backport of nullcontext 'contextvars;python_version<"3.7"', # backport of contextvars for savenload "typing_extensions;python_version<'3.8'", # backport of Final + "importlib_metadata;python_version<'3.10'", # backport of importlib "torch-runstats>=0.2.0", "torch-ema>=0.3.0", ], diff --git a/tests/integration/test_evaluate.py b/tests/integration/test_evaluate.py index 4dd9bce0..4a1388ed 100644 --- a/tests/integration/test_evaluate.py +++ b/tests/integration/test_evaluate.py @@ -171,9 +171,11 @@ def runit(params: dict): assert np.allclose( err, 0.0, - atol=1e-8 - if true_identity - else (1e-2 if metric.startswith("e") else 1e-4), + atol=( + 1e-8 + if true_identity + else (1e-2 if metric.startswith("e") else 1e-4) + ), ), f"Metric `{metric}` wasn't zero!" elif builder == ConstFactorModel: # TODO: check comperable to naive numpy compute diff --git a/tests/unit/model/test_pair/test_zbl.py b/tests/unit/model/test_pair/test_zbl.py index b862b624..c578cb6c 100644 --- a/tests/unit/model/test_pair/test_zbl.py +++ b/tests/unit/model/test_pair/test_zbl.py @@ -81,7 +81,7 @@ def test_lammps_repro(self, config): # $ lmp -in zbl_data.lmps # $ python -c "import numpy as np; d = np.loadtxt('zbl.dat', skiprows=1); np.save('zbl.npy', d)" refdata = np.load(Path(__file__).parent / "zbl.npy") - for (r, Zi, Zj, pe, fxi, fxj) in refdata: + for r, Zi, Zj, pe, fxi, fxj in refdata: if r >= r_max: continue atoms.positions[1, 0] = r diff --git a/tests/unit/trainer/test_trainer.py b/tests/unit/trainer/test_trainer.py index 197f3897..9e2bd64e 100644 --- a/tests/unit/trainer/test_trainer.py +++ b/tests/unit/trainer/test_trainer.py @@ -1,6 +1,7 @@ """ Trainer tests """ + import pytest import numpy as np @@ -45,12 +46,14 @@ def dummy_builder(): ) -@pytest.fixture(scope="function") -def trainer(float_tolerance): +def create_trainer(float_tolerance, **kwargs): """ - Generate a class instance with minimal configurations + Generate a class instance with minimal configurations, + with the option to modify the configurations using + kwargs. """ conf = minimal_config.copy() + conf.update(kwargs) conf["default_dtype"] = str(torch.get_default_dtype())[len("torch.") :] model = model_from_config(conf) with tempfile.TemporaryDirectory(prefix="output") as path: @@ -59,6 +62,14 @@ def trainer(float_tolerance): yield c +@pytest.fixture(scope="function") +def trainer(float_tolerance): + """ + Generate a class instance with minimal configurations. + """ + yield from create_trainer(float_tolerance) + + class TestTrainerSetUp: """ test initialization @@ -158,6 +169,134 @@ def test_split(self, trainer, nequip_dataset, mode): else: assert n_samples == trainer.n_train + @pytest.mark.parametrize("mode", ["random", "sequential"]) + @pytest.mark.parametrize( + "n_train_percent, n_val_percent", [("75%", "15%"), ("20%", "30%")] + ) + def test_split_w_percent_n_train_n_val( + self, nequip_dataset, mode, float_tolerance, n_train_percent, n_val_percent + ): + """ + Test case where n_train and n_val are given as percentage of the + dataset size, and here they don't sum to 100%. + """ + # nequip_dataset has 8 frames, so setting n_train to 75% and n_val to 15% should give 6 and 1 + # frames respectively. 
Note that summed percentages don't have to be 100% + trainer_w_percent_n_train_n_val = next( + create_trainer( + float_tolerance=float_tolerance, + n_train=n_train_percent, + n_val=n_val_percent, + ) + ) + trainer_w_percent_n_train_n_val.train_val_split = mode + trainer_w_percent_n_train_n_val.set_dataset(nequip_dataset) + for epoch_i in range(3): + trainer_w_percent_n_train_n_val.dl_train_sampler.step_epoch(epoch_i) + n_samples: int = 0 + n_val_samples: int = 0 + for i, batch in enumerate(trainer_w_percent_n_train_n_val.dl_train): + n_samples += batch[AtomicDataDict.BATCH_PTR_KEY].shape[0] - 1 + if trainer_w_percent_n_train_n_val.n_train_per_epoch is not None: + assert n_samples == trainer_w_percent_n_train_n_val.n_train_per_epoch + else: + assert ( + n_samples != trainer_w_percent_n_train_n_val.n_train + ) # n_train now a percentage + assert trainer_w_percent_n_train_n_val.n_train == n_train_percent # 75% + assert n_samples == int( + (float(n_train_percent.strip("%")) / 100) * len(nequip_dataset) + ) # 6 + assert trainer_w_percent_n_train_n_val.n_val == n_val_percent # 15% + + for i, batch in enumerate(trainer_w_percent_n_train_n_val.dl_val): + n_val_samples += batch[AtomicDataDict.BATCH_PTR_KEY].shape[0] - 1 + + assert ( + n_val_samples != trainer_w_percent_n_train_n_val.n_val + ) # n_val now a percentage + assert trainer_w_percent_n_train_n_val.n_val == n_val_percent # 15% + assert n_val_samples == int( + (float(n_val_percent.strip("%")) / 100) * len(nequip_dataset) + ) # 1 (floored) + + @pytest.mark.parametrize("mode", ["random", "sequential"]) + @pytest.mark.parametrize( + "n_train_percent, n_val_percent", [("70%", "30%"), ("55%", "45%")] + ) + def test_split_w_percent_n_train_n_val_flooring( + self, nequip_dataset, mode, float_tolerance, n_train_percent, n_val_percent + ): + """ + Test case where n_train and n_val are given as percentage of the + dataset size, summing to 100% but with a split that gives + non-integer numbers of frames for n_train and n_val. + (i.e. 
n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames, + so final n_train is 6 and n_val is 2) + """ + # nequip_dataset has 8 frames, so n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames, + # so final n_train is 6 and n_val is 2 + trainer_w_percent_n_train_n_val_flooring = next( + create_trainer( + float_tolerance=float_tolerance, + n_train=n_train_percent, + n_val=n_val_percent, + ) + ) + trainer_w_percent_n_train_n_val_flooring.train_val_split = mode + trainer_w_percent_n_train_n_val_flooring.set_dataset(nequip_dataset) + for epoch_i in range(3): + trainer_w_percent_n_train_n_val_flooring.dl_train_sampler.step_epoch( + epoch_i + ) + n_samples: int = 0 + n_val_samples: int = 0 + for i, batch in enumerate( + trainer_w_percent_n_train_n_val_flooring.dl_train + ): + n_samples += batch[AtomicDataDict.BATCH_PTR_KEY].shape[0] - 1 + if trainer_w_percent_n_train_n_val_flooring.n_train_per_epoch is not None: + assert ( + n_samples + == trainer_w_percent_n_train_n_val_flooring.n_train_per_epoch + ) + else: + assert ( + n_samples != trainer_w_percent_n_train_n_val_flooring.n_train + ) # n_train now a percentage + assert ( + trainer_w_percent_n_train_n_val_flooring.n_train == n_train_percent + ) # 70% + # _not_ equal to the bare floored value now: + assert n_samples != int( + (float(n_train_percent.strip("%")) / 100) * len(nequip_dataset) + ) # 5 + assert ( + n_samples + == int( # equal to floored value plus 1 + (float(n_train_percent.strip("%")) / 100) * len(nequip_dataset) + ) + + 1 + ) # 6 + assert ( + trainer_w_percent_n_train_n_val_flooring.n_val == n_val_percent + ) # 30% + + for i, batch in enumerate(trainer_w_percent_n_train_n_val_flooring.dl_val): + n_val_samples += batch[AtomicDataDict.BATCH_PTR_KEY].shape[0] - 1 + + assert ( + n_val_samples != trainer_w_percent_n_train_n_val_flooring.n_val + ) # n_val now a percentage + assert ( + trainer_w_percent_n_train_n_val_flooring.n_val == n_val_percent + ) # 30% + assert n_val_samples == int( + (float(n_val_percent.strip("%")) / 100) * len(nequip_dataset) + ) # 2 (floored) + + assert n_samples + n_val_samples == len(nequip_dataset) # 100% coverage + class TestTrain: def test_train(self, trainer, nequip_dataset): diff --git a/tests/unit/utils/test_config.py b/tests/unit/utils/test_config.py index 35ae7b68..22025ffe 100644 --- a/tests/unit/utils/test_config.py +++ b/tests/unit/utils/test_config.py @@ -1,6 +1,7 @@ """ Config tests """ + import pytest from os import remove diff --git a/tests/unit/utils/test_output.py b/tests/unit/utils/test_output.py index cdc7b4ac..ec79dd1f 100644 --- a/tests/unit/utils/test_output.py +++ b/tests/unit/utils/test_output.py @@ -1,6 +1,7 @@ """ Config tests """ + import pytest import tempfile
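Below is a minimal usage sketch of the new `nequip.extension` entry-point mechanism combined with the extended `register_fields` API from this diff (the loader added to `nequip/__init__.py` above scans the `nequip.extension` group and loads every entry point named `init_always`). The package name `nequip_bec` and the field name `born_charge` are hypothetical, chosen only for illustration:

```python
# setup.py of a hypothetical extension package "nequip-bec"
from setuptools import find_packages, setup

setup(
    name="nequip-bec",
    version="0.1.0",
    packages=find_packages(),
    install_requires=["nequip"],
    # `import nequip` discovers this group and calls .load() on every
    # entry point named "init_always" (see nequip/__init__.py above)
    entry_points={
        "nequip.extension": ["init_always = nequip_bec._fields"],
    },
)
```

```python
# nequip_bec/_fields.py -- imported automatically by the entry-point loader
from nequip.data import register_fields

# Register a hypothetical per-atom rank-2 Cartesian tensor field.  The formula
# string "ij" (no index symmetry) follows e3nn's CartesianTensor conventions,
# like the built-in "ij=ji" entries for stress and virial; ranks other than 2
# are rejected by the rank check added to register_fields above.
register_fields(
    node_fields=["born_charge"],
    cartesian_tensor_fields={"born_charge": "ij"},
)
```

A field registered this way should then be reshaped by `AtomicData.from_ase` (which accepts `(N_atom, 3, 3)` or `(N_atom, 9)` arrays for node fields) and decomposed into irreps by `assert_AtomicData_equivariant` during equivariance testing.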
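And a small sketch of selecting one of the alternate neighborlist backends that replace the removed `NEQUIP_MATSCIPY_NL` flag; this assumes the chosen backend package (here `vesin`) is installed. Since `nequip.data.AtomicData` reads `NEQUIP_NL` at import time, the variable must be set before the first import:

```python
import os

# valid values are "ase" (the default), "matscipy", and "vesin";
# must be set before nequip.data is imported, since it is read at import time
os.environ["NEQUIP_NL"] = "vesin"

from nequip.data import AtomicData  # noqa: E402
```

Note that, per the comments added in `AtomicData.py`, `vesin` does not support mixed periodic boundary conditions (the code raises a `ValueError` for them), so such systems still need the `ase` or `matscipy` backends.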