diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 070f2557..5f7c96cd 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -29,7 +29,7 @@ jobs: python-version: '3.x' - name: Install flake8 run: | - pip install flake8==4.0.1 + pip install flake8==7.0.0 - name: run flake8 run: | flake8 . --count --show-source --statistics diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1f835e90..0fb33150 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: strategy: matrix: python-version: [3.9] - torch-version: [1.11.0, 1.12.1] + torch-version: [1.13.1, "2.*"] steps: - uses: actions/checkout@v2 @@ -32,6 +32,7 @@ jobs: python -m pip install --upgrade pip pip install setuptools wheel pip install torch==${TORCH} -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install h5py scikit-learn # install packages that aren't required dependencies but that the tests do need pip install --upgrade-strategy only-if-needed . - name: Install pytest run: | diff --git a/.github/workflows/tests_develop.yml b/.github/workflows/tests_develop.yml index 2c23350c..d399e426 100644 --- a/.github/workflows/tests_develop.yml +++ b/.github/workflows/tests_develop.yml @@ -16,7 +16,7 @@ jobs: strategy: matrix: python-version: [3.9] - torch-version: [1.12.1] + torch-version: ["2.*"] steps: - uses: actions/checkout@v2 @@ -32,6 +32,7 @@ jobs: python -m pip install --upgrade pip pip install setuptools wheel pip install torch==${TORCH} -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install h5py scikit-learn # install packages that aren't required dependencies but that the tests do need pip install --upgrade-strategy only-if-needed . - name: Install pytest run: | diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..70205bbd --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,20 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file for details + +# Required +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.9" + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/conf.py + +# Optionally declare the Python requirements required to build your docs +python: + install: + - requirements: docs/requirements.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index cf50972d..c60ed185 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,51 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. +## Unreleased + +## [0.6.0] - 2024-5-10 +### Added +- add Tensorboard as logger option +- [Breaking] Refactor overall model logic into `GraphModel` top-level module +- [Breaking] Added `model_dtype` +- `BATCH_PTR_KEY` in `AtomicDataDict` +- `AtomicInMemoryDataset.rdf()` and `examples/rdf.py` +- `type_to_chemical_symbol` +- Pair potential terms +- `nequip-evaluate --output-fields-from-original-dataset` +- Error (or warn) on unused options in YAML that likely indicate typos +- `dataset_*_absmax` statistics option +- `HDF5Dataset` (#227) +- `include_file_as_baseline_config` for simple modifications of existing configs +- `nequip-deploy --using-dataset` to support data-dependent deployment steps +- Support for Gaussian Mixture Model uncertainty quantification (https://doi.org/10.1063/5.0136574) +- `start_of_epoch_callbacks` +- `nequip.train.callbacks.loss_schedule.SimpleLossSchedule` for changing the loss coefficients at specified epochs +- `nequip-deploy build --checkpoint` and `--override` to avoid many largely duplicated YAML files +- matscipy neighborlist support enabled with `NEQUIP_MATSCIPY_NL` environment variable + +### Changed +- Always require explicit `seed` +- [Breaking] Set `dataset_seed` to `seed` if it is not explicitly provided +- Don't log as often by default +- [Breaking] Default nonlinearities are `silu` (`e`) and `tanh` (`o`) +- Will not reproduce previous versions' data shuffling order (for all practical purposes this does not matter, the `shuffle` option is unchanged) +- [Breaking] `default_dtype` defaults to `float64` (`model_dtype` default `float32`, `allow_tf32: true` by default--- see https://arxiv.org/abs/2304.10061) +- `nequip-benchmark` now only uses `--n-data` frames to build the model +- [Breaking] By default models now use `StressForceOutput`, not `ForceOutput` +- Added `edge_energy` to `ALL_ENERGY_KEYS` subjecting it to global rescale + +### Fixed +- Work with `wandb>=0.13.8` +- Better error for standard deviation with too few data +- `load_model_state` GPU -> CPU +- No negative volumes in rare cases + +### Removed +- [Breaking] `fixed_fields` machinery (`npz_fixed_field_keys` is still supported, but through a more straightforward implementation) +- Default run name/WandB project name of `NequIP`, they must now always be provided explicitly +- [Breaking] Removed `_params` as an allowable subconfiguration suffix (i.e. instead of `optimizer_params` now only `optimizer_kwargs` is valid, not both) +- [Breaking] Removed `per_species_rescale_arguments_in_dataset_units` ## [0.5.6] - 2022-12-19 ### Added @@ -14,6 +59,7 @@ Most recent change on the bottom. - `nequip-benchmark --no-compile` and `--verbose` and `--memory-summary` - `nequip-benchmark --pdb` for debugging model (builder) errors - More information in `nequip-deploy info` +- GPU OOM offloading mode ### Changed - Minimum e3nn is now 0.4.4 diff --git a/README.md b/README.md index 7500f9ec..9c983f9f 100644 --- a/README.md +++ b/README.md @@ -13,11 +13,13 @@ NequIP is an open-source code for building E(3)-equivariant interatomic potentia NequIP requires: * Python >= 3.7 -* PyTorch >= 1.8, !=1.9, <=1.11.*. PyTorch can be installed following the [instructions from their documentation](https://pytorch.org/get-started/locally/). Note that neither `torchvision` nor `torchaudio`, included in the default install command, are needed for NequIP. +* PyTorch == `1.11.*` or `1.13.*` or later (do **not** use `1.12`). (Some users have observed silent issues with PyTorch 2+, as reported in #311. Please report any similar issues you encounter.) PyTorch can be installed following the [instructions from their documentation](https://pytorch.org/get-started/locally/). Note that neither `torchvision` nor `torchaudio`, included in the default install command, are needed for NequIP. + +**You must install PyTorch before installing NequIP, however it is not marked as a dependency of `nequip` to prevent `pip` from trying to overwrite your PyTorch installation.** To install: -* We use [Weights&Biases](https://wandb.ai) to keep track of experiments. This is not a strict requirement — you can use our package without it — but it may make your life easier. If you want to use it, create an account [here](https://wandb.ai) and install the Python package: +* We use [Weights&Biases](https://wandb.ai) (or TensorBoard) to keep track of experiments. This is not a strict requirement — you can use our package without it — but it may make your life easier. If you want to use it, create an account [here](https://wandb.ai) and install the Python package: ``` pip install wandb @@ -130,6 +132,12 @@ pair_coeff * * deployed.pth a img.logo, +.wy-side-nav-search>a img.logo { + max-width: 90%; +} + +/* link colors in sidebar */ +.wy-menu-vertical a { + color: #d9d9d9; +} \ No newline at end of file diff --git a/docs/commandline/commands.rst b/docs/commandline/commands.rst index b58c87ab..f371dc2b 100644 --- a/docs/commandline/commands.rst +++ b/docs/commandline/commands.rst @@ -1,132 +1,2 @@ -Command-line Executables -======================== - -``nequip-train`` ----------------- - - .. code :: - - usage: nequip-train [-h] [--equivariance-test] [--model-debug-mode] [--grad-anomaly-mode] [--log LOG] config - -Train (or restart training of) a NequIP model. - -positional arguments: - config YAML file configuring the model, dataset, and other options - -optional arguments: - -h, --help show this help message and exit - --equivariance-test test the model's equivariance before training - --model-debug-mode enable model debug mode, which can sometimes give much more useful error messages at the - cost of some speed. Do not use for production training! - --grad-anomaly-mode enable PyTorch autograd anomaly mode to debug NaN gradients. Do not use for production - training! - --log LOG log file to store all the screen logging - -``nequip-evaluate`` -------------------- - - .. code :: - - usage: nequip-evaluate [-h] [--train-dir TRAIN_DIR] [--model MODEL] [--dataset-config DATASET_CONFIG] - [--metrics-config METRICS_CONFIG] [--test-indexes TEST_INDEXES] [--batch-size BATCH_SIZE] - [--device DEVICE] [--output OUTPUT] [--log LOG] - -Compute the error of a model on a test set using various metrics. The model, metrics, dataset, etc. can specified -in individual YAML config files, or a training session can be indicated with ``--train-dir``. In order of priority, -the global settings (dtype, TensorFloat32, etc.) are taken from: (1) the model config (for a training session), (2) -the dataset config (for a deployed model), or (3) the defaults. Prints only the final result in ``name = num`` format -to stdout; all other information is ``logging.debug``ed to stderr. WARNING: Please note that results of CUDA models -are rarely exactly reproducible, and that even CPU models can be nondeterministic. - -optional arguments: - -h, --help show this help message and exit - --train-dir TRAIN_DIR - Path to a working directory from a training session. - --model MODEL A deployed or pickled NequIP model to load. If omitted, defaults to `best_model.pth` in - `train_dir`. - --dataset-config DATASET_CONFIG - A YAML config file specifying the dataset to load test data from. If omitted, `config.yaml` - in `train_dir` will be used - --metrics-config METRICS_CONFIG - A YAML config file specifying the metrics to compute. If omitted, `config.yaml` in - `train_dir` will be used. If the config does not specify `metrics_components`, the default - is to logging.debug MAEs and RMSEs for all fields given in the loss function. If the - literal string `None`, no metrics will be computed. - --test-indexes TEST_INDEXES - Path to a file containing the indexes in the dataset that make up the test set. If omitted, - all data frames *not* used as training or validation data in the training session - `train_dir` will be used. - --batch-size BATCH_SIZE - Batch size to use. Larger is usually faster on GPU. - --device DEVICE Device to run the model on. If not provided, defaults to CUDA if available and CPU - otherwise. - --output OUTPUT XYZ file to write out the test set and model predicted forces, energies, etc. to. - --log LOG log file to store all the metrics and screen logging.debug - -``nequip-deploy`` ------------------ - - .. code :: - - usage: nequip-deploy [-h] {info,build} ... - -Deploy and view information about previously deployed NequIP models. - -optional arguments: - -h, --help show this help message and exit - -commands: - {info,build} - info Get information from a deployed model file - build Build a deployment model - -``nequip-deploy info`` -~~~~~~~~~~~~~~~~~~~~~~ - - .. code :: - - usage: nequip-deploy info [-h] model_path - -positional arguments: - model_path Path to a deployed model file. - -optional arguments: - -h, --help show this help message and exit - - -``nequip-deploy build`` -~~~~~~~~~~~~~~~~~~~~~~~ - - .. code :: - - usage: nequip-deploy build [-h] train_dir out_file - -positional arguments: - train_dir Path to a working directory from a training session. - out_file Output file for deployed model. - -optional arguments: - -h, --help show this help message and exit - - -``nequip-benchmark`` --------------------- - - .. code :: - - usage: nequip-benchmark [-h] [--profile PROFILE] [--device DEVICE] [-n N] [--n-data N_DATA] [--timestep TIMESTEP] - config - -Benchmark the approximate MD performance of a given model configuration / dataset pair. - -positional arguments: - config configuration file - -optional arguments: - -h, --help show this help message and exit - --profile PROFILE Profile instead of timing, creating and outputing a Chrome trace JSON to the given path. - --device DEVICE Device to run the model on. If not provided, defaults to CUDA if available and CPU - otherwise. - -n N Number of trials. - --n-data N_DATA Number of frames to use. - --timestep TIMESTEP MD timestep for ns/day esimation, in fs. Defauts to 1fs. +Command-line tools +================== diff --git a/docs/conf.py b/docs/conf.py index 11a5afca..808e052e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -27,7 +27,13 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_rtd_theme"] +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx_rtd_theme", + "myst_parser", +] +source_suffix = [".rst", ".md"] # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -49,3 +55,12 @@ # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] +html_favicon = "favicon.png" +html_logo = "logo.png" +html_theme_options = { + "logo_only": True, +} + + +def setup(app): + app.add_css_file("custom.css") diff --git a/docs/errors/errors.md b/docs/errors/errors.md new file mode 100644 index 00000000..a4422fa0 --- /dev/null +++ b/docs/errors/errors.md @@ -0,0 +1,22 @@ +Common errors and warnings +========================== + +#### Unused keys + + - ```txt + KeyError: 'The following keys in the config file were not used, did you make a typo?: optimizer_params. + ``` + Since >=0.6.0, using `prefix_params` style subdictionaries of options is no longer supported. Only `_kwargs` is supported, i.e. `optimizer_kwargs`. Please update your YAML configs. + +#### Out-of-memory errors + + - ...with `nequip-evaluate` + + Choose a lower ``--batch-size``; while the highest value that fits in your GPU memory is good for performance, + lowering this does *not* affect the final results (beyond numerics). + +#### Other + + - Various shape errors + + Check the sanity of the shapes in your dataset. \ No newline at end of file diff --git a/docs/errors/errors.rst b/docs/errors/errors.rst deleted file mode 100644 index 576e553d..00000000 --- a/docs/errors/errors.rst +++ /dev/null @@ -1,12 +0,0 @@ -Errors -====== - -Common errors -------------- - -Various shape errors - Check the sanity of the shapes in your dataset. - -Out-of-memory errors with `nequip-evaluate` - Choose a lower ``--batch-size``; while the highest value that fits in your GPU memory is good for performance, - lowering this does *not* affect the final results (beyond numerics). diff --git a/docs/faq/FAQ.md b/docs/faq/FAQ.md new file mode 100644 index 00000000..bd00a4c2 --- /dev/null +++ b/docs/faq/FAQ.md @@ -0,0 +1,8 @@ +# FAQ + +## Loss functions + + - Despite changing the coefficients in `loss_coeffs`, the magnitude of my training loss isn't changing! + + Inidividual loss terms like `training_loss_f`, `training_loss_e`, etc. are reported **before** they are scaled by their coefficients for summing into the total loss. + diff --git a/docs/faq/FAQ.rst b/docs/faq/FAQ.rst deleted file mode 100644 index 411e77c1..00000000 --- a/docs/faq/FAQ.rst +++ /dev/null @@ -1,14 +0,0 @@ -FAQ -=== - -How do I... ------------ - -... continue to train a model that reached a stopping condition? - There will be an answer here. - -1. Reload the model trained with version 0.3.3 to the code in 0.4. - check out the migration note at :ref:`migration_note`. - -2. Specify my dataset for `nequip-train` and `nequip-eval`, see :ref:`_dataset_note`. - diff --git a/docs/favicon.png b/docs/favicon.png new file mode 100644 index 00000000..f66789ee Binary files /dev/null and b/docs/favicon.png differ diff --git a/docs/howto/conventions.md b/docs/howto/conventions.md new file mode 100644 index 00000000..3964fef2 --- /dev/null +++ b/docs/howto/conventions.md @@ -0,0 +1,29 @@ +# Conventions and units + +## Conventions + - Cells vectors are given in ASE style as the **rows** of the cell matrix + - The first index in an edge tuple (``edge_index[0]``) is the center atom, and the second (``edge_index[1]``) is the neighbor + +## Units + +`nequip` has no prefered system of units; models, errors, predictions, etc. will always be in the units of the original dataset used. + +```{warning} +`nequip` cannot and does not check the consistency of units in inputs you provide, and it is your responsibility to ensure consistent treatment of input and output units +``` + +Losses (`training_loss_f`, `validation_loss_e`, etc.) do **not** have physical units. Errors (`training_f_rmse`, `validation_f_rmse`) are always reported in physical units. + +## Pressure / stress / virials + +`nequip` always expresses stress in the "consistent" units of `energy / length^3`, which are **not** the typical physical units used by many codes for stress. + +```{warning} +Training labels for stress in the original dataset must be pre-processed by the user to be in consistent units. +``` + +Stress also includes an arbitrary sign convention, for which we adopt the choice that `virial = -stress x volume <=> stress = (-1/volume) * virial`. + +```{warning} +Training labels for stress in the original dataset must be pre-processed by the user to be in **this sign convention**, which they may or may not already be depending on their origin. +``` \ No newline at end of file diff --git a/docs/howto/conventions.rst b/docs/howto/conventions.rst deleted file mode 100644 index f4679a76..00000000 --- a/docs/howto/conventions.rst +++ /dev/null @@ -1,5 +0,0 @@ -Conventions -=========== - - - Cells vectors are given in ASE style as the **rows** of the cell matrix - - The first index in an edge tuple (``edge_index[0]``) is the center atom, and the second (``edge_index[1]``) is the neighbor \ No newline at end of file diff --git a/docs/howto/dataset.rst b/docs/howto/dataset.rst index 2b5267e7..7c535d47 100644 --- a/docs/howto/dataset.rst +++ b/docs/howto/dataset.rst @@ -25,14 +25,6 @@ NequIP will not automatically update the cached data. Key concepts ------------ -fixed_fields -~~~~~~~~~~~~ -Fixed fields are the quantities that are shared among all the configurations in the dataset. -For example, if the dataset is a trajectory of an NVT MD simulation, the super cell size and the atomic species -are indeed a constant matrix/vector through out the whole dataset. -In this case, in stead of repeating the same values for many times, -we specify the cell and species as fixed fields and only provide them once. - yaml interface ~~~~~~~~~~~~~~ ``nequip-train`` and ``nequip-evaluate`` automatically construct the AtomicDataset based on the yaml arguments. @@ -45,7 +37,7 @@ For example, ``dataset_file_name`` is used for training data and ``validation_da Python interface ~~~~~~~~~~~~~~~~ -See ``nequip.data.dataset.AtomicInMemoryDataset``. +See ``nequip.data.AtomicInMemoryDataset``. Prepare dataset and specify in yaml config ------------------------------------------ @@ -108,6 +100,12 @@ In the npz file, all the values should have the same row as the number of the co For example, the force array of 36 atomic configurations of an N-atom system should have the shape of (36, N, 3); their total_energy array should have the shape of (36). +NPZ also supports "fixed fields." Fixed fields are the quantities that are shared among all the configurations in the dataset. +For example, if the dataset is a trajectory of an NVT MD simulation, the super cell size and the atomic species +are indeed a constant matrix/vector through out the whole dataset. +In this case, in stead of repeating the same values for many times, +we specify the cell and species as fixed fields and only provide them once. + Below is an example of the yaml specification. .. code:: yaml diff --git a/docs/howto/howto.rst b/docs/howto/howto.rst index 07e84e84..eb376f54 100644 --- a/docs/howto/howto.rst +++ b/docs/howto/howto.rst @@ -3,5 +3,5 @@ How-to Tutorials .. toctree:: + conventions dataset - migrate diff --git a/docs/index.rst b/docs/index.rst index d2edd1a6..0bd1922b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -14,18 +14,13 @@ NequIP is an open-source package for creating, training, and using E(3)-equivari introduction/intro cite - installation/install - yaml/yaml howto/howto faq/FAQ commandline/commands - lammps/all - options/options + integrations/all api/nequip errors/errors - - Indices and tables ================== diff --git a/docs/installation/install.rst b/docs/installation/install.rst deleted file mode 100644 index 3e946815..00000000 --- a/docs/installation/install.rst +++ /dev/null @@ -1,39 +0,0 @@ -Installation -============ - -NequIP requires: - - * Python >= 3.6 - * PyTorch >= 1.8, <=1.11.*. PyTorch can be installed following the `instructions from their documentation `_. Note that neither ``torchvision`` nor ``torchaudio``, included in the default install command, are needed for NequIP. - -To install: - - * We use `Weights&Biases `_ to keep track of experiments. This is not a strict requirement — you can use our package without it — but it may make your life easier. If you want to use it, create an account `here `_ and install the Python package:: - - pip install wandb - - * Install the latest stable NequIP:: - - pip install https://github.com/mir-group/nequip/archive/main.zip - -To install previous versions of NequIP, please clone the repository from GitHub and check out the appropriate tag (for example ``v0.3.3`` for version 0.3.3). - -To install the current **unstable** development version of NequIP, please clone our repository and check out the ``develop`` branch. - -Installation Issues -------------------- - -The easiest way to check if your installation is working is to train a _toy_ model:: - - nequip-train configs/minimal.yaml - -If you suspect something is wrong, encounter errors, or just want to confirm that everything is in working order, you can also run the unit tests:: - - pip install pytest - pytest tests/unit/ - -To run the full tests, including a set of longer/more intensive integration tests, run:: - - pytest tests/ - -If a GPU is present, the unit tests will use it. \ No newline at end of file diff --git a/docs/lammps/all.rst b/docs/integrations/all.rst similarity index 100% rename from docs/lammps/all.rst rename to docs/integrations/all.rst diff --git a/docs/lammps/ase.rst b/docs/integrations/ase.rst similarity index 100% rename from docs/lammps/ase.rst rename to docs/integrations/ase.rst diff --git a/docs/lammps/lammps.rst b/docs/integrations/lammps.rst similarity index 100% rename from docs/lammps/lammps.rst rename to docs/integrations/lammps.rst diff --git a/docs/introduction/intro.md b/docs/introduction/intro.md new file mode 100644 index 00000000..acdf9ada --- /dev/null +++ b/docs/introduction/intro.md @@ -0,0 +1,5 @@ +# Overview + +## Installation + +See [`README.md`](https://github.com/mir-group/nequip/) diff --git a/docs/introduction/intro.rst b/docs/introduction/intro.rst deleted file mode 100644 index e0dcc32c..00000000 --- a/docs/introduction/intro.rst +++ /dev/null @@ -1,4 +0,0 @@ -Overview -======== - -TODO diff --git a/docs/logo.png b/docs/logo.png new file mode 100644 index 00000000..deb4ee3b Binary files /dev/null and b/docs/logo.png differ diff --git a/docs/options/HOWTO.md b/docs/options/HOWTO.md deleted file mode 100644 index 44bc5508..00000000 --- a/docs/options/HOWTO.md +++ /dev/null @@ -1,32 +0,0 @@ -Add this code to `auto_init.py`: - -```python -f = open("auto_all_options.rst", "w") - - -def print_option(builder, file): - print(f"!! {builder.__name__}", file=f) - if inspect.isclass(builder): - builder = builder.__init__ - sig = inspect.signature(builder) - for k, v in sig.parameters.items(): - if k == "self": - continue - print(k, file=f) - print(len(k) * "^", file=f) - if v.default == inspect.Parameter.empty: - print(f" | Type:", file=f) - print( - f" | Default: n/a\n", - file=f, - ) - else: - typestr = type(v.default).__name__ - print(f" | Type: {typestr}", file=f) - print( - f" | Default: ``{str(v.default)}``\n", - file=f, - ) -``` - -and call the function in every `instantiate`. \ No newline at end of file diff --git a/docs/options/dataset.rst b/docs/options/dataset.rst deleted file mode 100644 index f3ca194c..00000000 --- a/docs/options/dataset.rst +++ /dev/null @@ -1,78 +0,0 @@ -Dataset -======= - -Basic ------ - -r_max -^^^^^ - See :ref:`r_max_option`. - -type_names -^^^^^^^^^^ - | Type: NoneType - | Default: ``None`` - -chemical_symbols -^^^^^^^^^^^^^^^^ - | Type: NoneType - | Default: ``None`` - -chemical_symbol_to_type -^^^^^^^^^^^^^^^^^^^^^^^ - | Type: NoneType - | Default: ``None`` - -avg_num_neighbors -^^^^^^^^^^^^^^^^^ - | Type: NoneType - | Default: ``None`` - -key_mapping -^^^^^^^^^^^ - | Type: dict - | Default: ``{'positions': 'pos', 'energy': 'total_energy', 'force': 'forces', 'forces': 'forces', 'Z': 'atomic_numbers', 'atomic_number': 'atomic_numbers'}`` - -include_keys -^^^^^^^^ - | Type: list - | Default: ``[]`` - -npz_fixed_field_keys -^^^^^^^^^^^^^^^^^^^^ - | Type: list - | Default: ``[]`` - -file_name -^^^^^^^^^ - | Type: NoneType - | Default: ``None`` - -url -^^^ - | Type: NoneType - | Default: ``None`` - -force_fixed_keys -^^^^^^^^^^^^^^^^ - | Type: list - | Default: ``[]`` - -extra_fixed_fields -^^^^^^^^^^^^^^^^^^ - | Type: dict - | Default: ``{}`` - -include_frames -^^^^^^^^^^^^^^ - | Type: NoneType - | Default: ``None`` - -ase_args -^^^^^^^^ - | Type: dict - | Default: ``{}`` - -Advanced --------- -See tutorial on :ref:`../guide/_dataset_note`. diff --git a/docs/options/general.rst b/docs/options/general.rst deleted file mode 100644 index 1b75b6d9..00000000 --- a/docs/options/general.rst +++ /dev/null @@ -1,28 +0,0 @@ -General -======= - -Basic ------ - -root -^^^^ - | Type: - | Default: n/a - -run_name -^^^^^^^^ - | Type: path - | Default: n/a - - ``run_name`` specifies something about whatever - -Advanced --------- - -allow_tf32 -^^^^^^^^^^ - | Type: bool - | Default: ``False`` - - If ``False``, the use of NVIDIA's TensorFloat32 on Tensor Cores (Ampere architecture and later) will be disabled. - If ``True``, the PyTorch defaults (use anywhere possible) will remain. \ No newline at end of file diff --git a/docs/options/logging.rst b/docs/options/logging.rst deleted file mode 100644 index 675cdc45..00000000 --- a/docs/options/logging.rst +++ /dev/null @@ -1,8 +0,0 @@ -Logging -======= - -Basic ------ - -Advanced --------- \ No newline at end of file diff --git a/docs/options/model.rst b/docs/options/model.rst deleted file mode 100644 index a9ecb694..00000000 --- a/docs/options/model.rst +++ /dev/null @@ -1,149 +0,0 @@ -Model -===== - -Edge Basis -********** - -Basic ------ - -.. _r_max_option: - -r_max -^^^^^ - | Type: float - | Default: n/a - - The cutoff radius within which an atom is considered a neighbor. - -irreps_edge_sh -^^^^^^^^^^^^^^ - | Type: :ref:`Irreps` or int - | Default: n/a - - The irreps to use for the spherical harmonic projection of the edges. - If an integer, specifies all spherical harmonics up to and including that integer as :math:`\ell_{\text{max}}`. - If provided as explicit irreps, all multiplicities should be 1. - -num_basis -^^^^^^^^^ - | Type: int - | Default: ``8`` - - The number of radial basis functions to use. - -chemical_embedding_irreps_out -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | Type: :ref:`Irreps` - | Default: n/a - - The size of the linear embedding of the chemistry of an atom. - -Advanced --------- - -BesselBasis_trainable -^^^^^^^^^^^^^^^^^^^^^ - | Type: bool - | Default: ``True`` - - Whether the Bessel radial basis should be trainable. - -basis -^^^^^ - | Type: type - | Default: ```` - - The radial basis to use. - -Convolution -*********** - -Basic ------ - -num_layers -^^^^^^^^^^ - | Type: int - | Default: ``3`` - - The number of convolution layers. - - -feature_irreps_hidden -^^^^^^^^^^^^^^^^^^^^^ - | Type: :ref:`Irreps` - | Default: n/a - - Specifies the irreps and multiplicities of the hidden features. - Typically, include irreps with all :math:`\ell` values up to :math:`\ell_{\text{max}}` (see `irreps_edge_sh`_), each with both even and odd parity. - For example, for ``irreps_edge_sh: 1``, one might provide: ``feature_irreps_hidden: 16x0e + 16x0o + 16x1e + 16x1o``. - -Advanced --------- - -invariant_layers -^^^^^^^^^^^^^^^^ - | Type: int - | Default: ``1`` - - The number of hidden layers in the radial neural network. - -invariant_neurons -^^^^^^^^^^^^^^^^^ - | Type: int - | Default: ``8`` - - The width of the hidden layers of the radial neural network. - -resnet -^^^^^^ - | Type: bool - | Default: ``False`` - -nonlinearity_type -^^^^^^^^^^^^^^^^^ - | Type: str - | Default: ``gate`` - -nonlinearity_scalars -^^^^^^^^^^^^^^^^^^^^ - | Type: dict - | Default: ``{'e': 'ssp', 'o': 'tanh'}`` - -nonlinearity_gates -^^^^^^^^^^^^^^^^^^ - | Type: dict - | Default: ``{'e': 'ssp', 'o': 'abs'}`` - -use_sc -^^^^^^ - | Type: bool - | Default: ``True`` - -Output block -************ - -Basic ------ - -conv_to_output_hidden_irreps_out -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | Type: :ref:`Irreps` - | Default: n/a - - The middle (hidden) irreps of the output block. Should only contain irreps that are contained in the output of the network (``0e`` for potentials). - -Advanced --------- - - - - - - - - - - - diff --git a/docs/options/options.rst b/docs/options/options.rst deleted file mode 100644 index 95ab66ea..00000000 --- a/docs/options/options.rst +++ /dev/null @@ -1,10 +0,0 @@ -All Options -=========== - - .. toctree:: - - general - dataset - model - training - logging diff --git a/docs/options/training.rst b/docs/options/training.rst deleted file mode 100644 index b8c1711b..00000000 --- a/docs/options/training.rst +++ /dev/null @@ -1,8 +0,0 @@ -Training -======== - -Basic ------ - -Advanced --------- \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..a36b74ed --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,2 @@ +myst-parser +sphinx_rtd_theme diff --git a/docs/yaml/yaml.rst b/docs/yaml/yaml.rst deleted file mode 100644 index fd804436..00000000 --- a/docs/yaml/yaml.rst +++ /dev/null @@ -1,4 +0,0 @@ -YAML input -========== - -TODO diff --git a/examples/gmm_script.py b/examples/gmm_script.py new file mode 100644 index 00000000..239c6e99 --- /dev/null +++ b/examples/gmm_script.py @@ -0,0 +1,86 @@ +"""Example script to plot GMM uncertainties vs. atomic force errors from the results of `nequip-evaluate` + +To obtain GMM uncertainties for each atom in a system, a NequIP model must be trained +(e.g., using `nequip-train configs/minimal.yaml`) and then deployed. To fit a GMM +during deployment, run + + nequip-deploy build --using-dataset --model deployment.yaml deployed_model.pth + +where deployment.yaml is a config file that adds and fits a GMM to the deployed model +(for an example, see configs/minimal_gmm.yaml). Lastly, to obtain negative log +likelihoods (NLLs) on some test data, the NequIP model must be evaluated on a data set using +`nequip-evaluate` with `--output-fields node_features_nll` and +`--output-fields-from-original-dataset forces`. For example, running + + nequip-evaluate --dataset-config path/to/dataset-config.yaml --model deployed_model.pth --output out.xyz --output-fields node_features_nll --output-fields-from-original-dataset forces + +will evaluate deployed_model.pth (which includes the fitted GMM) on the data set in the config at +path/to/dataset-config.yaml and will write the NLLs and the true atomic forces (along +with the typical outputs of `nequip-evaluate`) to out.xyz. + +IMPORTANT: The data set config must contain the lines + + node_fields: + - node_features_nll + +in order for nequip-evaluate to recognize `node_features_nll` as a legitimate output. + +This script can then use out.xyz to create a plot of NLL vs. atomic force RMSE: + + python gmm_script.py out.xyz --output plot.png +""" + +import argparse + +import numpy as np +import matplotlib as mpl +import matplotlib.pyplot as plt + +from ase.io import read + +# Parse arguments +parser = argparse.ArgumentParser( + description="Make a plot of GMM NLL uncertainty vs. atomic force RMSE from the results of `nequip-evaluate`." +) +parser.add_argument( + "xyzoutput", + help=".xyz file from running `nequip-evaluate ... --output out.xyz --output-fields node_features_nll --output-fields-from-original-dataset forces", +) +parser.add_argument("--output", help="File to write plot to", default=None) +args = parser.parse_args() + +pred_forces = [] +true_forces = [] +nlls = [] + +# Extract predicted forces, true forces, and per-atom NLLs from evaluation +for frame in read(args.xyzoutput, index=":", format="extxyz"): + pred_forces.append(frame.get_forces()) + true_forces.append(frame.get_array("original_dataset_forces")) + nlls.append(frame.get_array("node_features_nll")) +pred_forces = np.concatenate(pred_forces, axis=0) +true_forces = np.concatenate(true_forces, axis=0) +nlls = np.concatenate(nlls, axis=0) + +# Compute force RMSE for each atom +force_rmses = np.sqrt(np.mean(np.square(true_forces - pred_forces), axis=-1)) + +# Plot per-atom NLL vs. per-atom force RMSE +f = plt.figure(figsize=(6, 6)) +plt.hist2d( + force_rmses, + nlls, + bins=(100, 100), + cmap="viridis", + norm=mpl.colors.LogNorm(), + cmin=1, +) +plt.title("NLL vs. Atomic Force RMSE") +plt.xlabel("Per-atom Force RMSE [force units]") +plt.ylabel("Per-atom Negative Log Likelihood (NLL) [unitless]") +plt.grid(linestyle="--") +plt.tight_layout() +if args.output is None: + plt.show() +else: + plt.savefig(args.output) diff --git a/examples/lj/README.md b/examples/lj/README.md index 424cbfb9..1483f2bb 100644 --- a/examples/lj/README.md +++ b/examples/lj/README.md @@ -1,3 +1,8 @@ +Lennard-Jones Custom Module Example +=================================== + +Note: for production simulations, a more appropriate Lennard-Jones energy term is provided in `nequip.model.PairPotentialTerm` / `nequip.model.PairPotential`. + Run commands with ``` PYTHONPATH=`pwd`:$PYTHONPATH nequip-* ... diff --git a/examples/parity_plot.py b/examples/parity_plot.py new file mode 100644 index 00000000..3e825c8a --- /dev/null +++ b/examples/parity_plot.py @@ -0,0 +1,60 @@ +"""Example script to make a parity plot from the results of `nequip-evaluate`. + +Thanks to Hongyu Yu for useful input: https://github.com/mir-group/nequip/discussions/223#discussioncomment-4923323 +""" + +import argparse +import numpy as np + +import matplotlib.pyplot as plt + +import ase.io + +# Parse arguments: +parser = argparse.ArgumentParser( + description="Make a parity plot from the results of `nequip-evaluate`." +) +parser.add_argument( + "xyzoutput", + help=".xyz file from running something like `nequip-evaluate ... --output out.xyz --output-fields-from-original-dataset total_energy,forces", +) +parser.add_argument("--output", help="File to write plot to", default=None) +args = parser.parse_args() + +forces = [] +true_forces = [] +energies = [] +true_energies = [] +for frame in ase.io.iread(args.xyzoutput): + forces.append(frame.get_forces().flatten()) + true_forces.append(frame.arrays["original_dataset_forces"].flatten()) + energies.append(frame.get_potential_energy()) + true_energies.append(frame.info["original_dataset_total_energy"]) +forces = np.concatenate(forces, axis=0) +true_forces = np.concatenate(true_forces, axis=0) +energies = np.asarray(energies) +true_energies = np.asarray(true_energies) + +fig, axs = plt.subplots(ncols=2, figsize=(8, 4)) + +ax = axs[0] +ax.set_xlabel("True force component") +ax.set_ylabel("Model force component") +ax.plot([0, 1], [0, 1], transform=ax.transAxes, linestyle="--", color="gray") +ax.scatter(true_forces, forces) +ax.set_aspect("equal") + +ax = axs[1] +ax.set_xlabel("True energy") +ax.set_ylabel("Model energy") +ax.plot([0, 1], [0, 1], transform=ax.transAxes, linestyle="--", color="gray") +ax.scatter(true_energies, energies) +ax.set_aspect("equal") + +plt.suptitle("Parity Plots") + +plt.tight_layout() +if args.output is None: + plt.show() +else: + plt.savefig(args.output) diff --git a/examples/plot_dimers.py b/examples/plot_dimers.py new file mode 100644 index 00000000..bafac7ac --- /dev/null +++ b/examples/plot_dimers.py @@ -0,0 +1,99 @@ +"""Plot energies of two-atom dimers from a NequIP model.""" + +import argparse +import itertools +from pathlib import Path + +from scipy.special import comb +import matplotlib.pyplot as plt + +import torch + +from nequip.data import AtomicData, AtomicDataDict +from nequip.scripts.evaluate import _load_deployed_or_traindir + +# Parse arguments: +parser = argparse.ArgumentParser( + description="Plot energies of two-atom dimers from a NequIP model" +) +parser.add_argument("model", help="Training dir or deployed model", type=Path) +parser.add_argument( + "--device", help="Device", default="cuda" if torch.cuda.is_available() else "cpu" +) +parser.add_argument("--output", help="File to write plot to", default=None) +parser.add_argument("--r-min", default=1.0, type=float) +parser.add_argument("--r-max", default=None, type=float) +parser.add_argument("--n-samples", default=500, type=int) +args = parser.parse_args() + +print("Loading model... ") +model, loaded_deployed_model, model_r_max, type_names = _load_deployed_or_traindir( + args.model, device=args.device +) +print(f" loaded{' deployed' if loaded_deployed_model else ''} model") +num_types = len(type_names) + +if args.r_max is not None: + model_r_max = args.r_max + +print("Computing dimers...") +potential = {} +N_sample = args.n_samples +N_combs = len(list(itertools.combinations_with_replacement(range(num_types), 2))) +r = torch.zeros(N_sample * N_combs, 2, 3, device=args.device) +rs_one = torch.linspace(args.r_min, model_r_max, 500, device=args.device) +rs = rs_one.repeat([N_combs]) +assert rs.shape == (N_combs * N_sample,) +r[:, 1, 0] += rs # offset second atom along x axis +types = torch.as_tensor( + [list(e) for e in itertools.combinations_with_replacement(range(num_types), 2)] +) +types = types.reshape(N_combs, 1, 2).expand(N_combs, N_sample, 2).reshape(-1) +r = r.reshape(-1, 3) +assert types.shape == r.shape[:1] +N_at_total = N_sample * N_combs * 2 +assert len(types) == N_at_total +edge_index = torch.vstack( + ( + torch.arange(N_at_total, device=args.device, dtype=torch.long), + torch.arange(1, N_at_total + 1, device=args.device, dtype=torch.long) + % N_at_total, + ) +) +data = AtomicData(pos=r, atom_types=types, edge_index=edge_index) +data.batch = torch.arange(N_sample * N_combs, device=args.device).repeat_interleave(2) +data.ptr = torch.arange(0, 2 * N_sample * N_combs + 1, 2, device=args.device) +result = model(AtomicData.to_AtomicDataDict(data.to(device=args.device))) + +print("Plotting...") +energies = ( + result[AtomicDataDict.TOTAL_ENERGY_KEY] + .reshape(N_combs, N_sample) + .cpu() + .detach() + .numpy() +) +del result +rs_one = rs_one.cpu().numpy() +nrows = int(comb(N=num_types, k=2, repetition=True)) +fig, axs = plt.subplots( + nrows=nrows, + sharex=True, + figsize=(6, 2 * nrows), + dpi=120, +) + +for i, (type1, type2) in enumerate( + itertools.combinations_with_replacement(range(num_types), 2) +): + ax = axs[i] + ax.set_ylabel(f"{type_names[type1]}-{type_names[type2]}") + ax.plot(rs_one, energies[i]) + +ax.set_xlabel("Distance") +plt.suptitle("$E_\\mathrm{total}$ for two-atom pairs") +plt.tight_layout() +if args.output is None: + plt.show() +else: + plt.savefig(args.output) diff --git a/examples/rdf.py b/examples/rdf.py new file mode 100644 index 00000000..c44c9c71 --- /dev/null +++ b/examples/rdf.py @@ -0,0 +1,55 @@ +"""Example of loading a NequIP dataset and computing its RDFs""" + +import argparse +import itertools + +from scipy.special import comb +import matplotlib.pyplot as plt + +from nequip.utils import Config +from nequip.data import dataset_from_config +from nequip.scripts.train import default_config +from nequip.utils._global_options import _set_global_options + +# Parse arguments: +parser = argparse.ArgumentParser( + description="Plot RDFs of dataset specified in a `nequip` YAML file" +) +parser.add_argument("config", help="YAML file configuring dataset") +parser.add_argument("--output", help="File to write plot to", default=None) +args = parser.parse_args() +config = Config.from_file(args.config, defaults=default_config) +_set_global_options(config) + +print("Loading dataset...") +r_max = config["r_max"] +dataset = dataset_from_config(config=config) +print( + f" loaded dataset of {len(dataset)} frames with {dataset.type_mapper.num_types} types" +) + +print("Computing RDFs...") +rdfs = dataset.rdf(bin_width=0.01) + +print("Plotting...") +num_types: int = dataset.type_mapper.num_types +fig, axs = plt.subplots(nrows=int(comb(N=num_types, k=2, repetition=True)), sharex=True) + +for i, (type1, type2) in enumerate( + itertools.combinations_with_replacement(range(num_types), 2) +): + ax = axs[i] + ax.set_ylabel( + f"{dataset.type_mapper.type_names[type1]}-{dataset.type_mapper.type_names[type2]}" + ) + hist, bin_edges = rdfs[(type1, type2)] + ax.plot(bin_edges[:-1], hist) + +ax.set_xlabel("Distance") +plt.suptitle("RDF") + +plt.tight_layout() +if args.output is None: + plt.show() +else: + plt.savefig(args.output) diff --git a/nequip/_version.py b/nequip/_version.py index b02164d2..8e22989a 100644 --- a/nequip/_version.py +++ b/nequip/_version.py @@ -2,4 +2,4 @@ # See Python packaging guide # https://packaging.python.org/guides/single-sourcing-package-version/ -__version__ = "0.5.6" +__version__ = "0.6.0" diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 728c260b..70c8fd2e 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -5,8 +5,9 @@ import warnings from copy import deepcopy -from typing import Union, Tuple, Dict, Optional, List, Set, Sequence +from typing import Union, Tuple, Dict, Optional, List, Set, Sequence, Final from collections.abc import Mapping +import os import numpy as np import ase.neighborlist @@ -49,6 +50,8 @@ AtomicDataDict.EDGE_ATTRS_KEY, AtomicDataDict.EDGE_EMBEDDING_KEY, AtomicDataDict.EDGE_FEATURES_KEY, + AtomicDataDict.EDGE_CUTOFF_KEY, + AtomicDataDict.EDGE_ENERGY_KEY, } _DEFAULT_GRAPH_FIELDS: Set[str] = { AtomicDataDict.TOTAL_ENERGY_KEY, @@ -56,6 +59,7 @@ AtomicDataDict.VIRIAL_KEY, AtomicDataDict.PBC_KEY, AtomicDataDict.CELL_KEY, + AtomicDataDict.BATCH_PTR_KEY, } _NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS) _EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS) @@ -78,6 +82,7 @@ def register_fields( node_fields: set = set(node_fields) edge_fields: set = set(edge_fields) graph_fields: set = set(graph_fields) + long_fields: set = set(long_fields) allfields = node_fields.union(edge_fields, graph_fields) assert len(allfields) == len(node_fields) + len(edge_fields) + len(graph_fields) _NODE_FIELDS.update(node_fields) @@ -109,6 +114,17 @@ def deregister_fields(*fields: Sequence[str]) -> None: _GRAPH_FIELDS.discard(f) +def _register_field_prefix(prefix: str) -> None: + """Re-register all registered fields as the same type, but with `prefix` added on.""" + assert prefix.endswith("_") + register_fields( + node_fields=[prefix + e for e in _NODE_FIELDS], + edge_fields=[prefix + e for e in _EDGE_FIELDS], + graph_fields=[prefix + e for e in _GRAPH_FIELDS], + long_fields=[prefix + e for e in _LONG_FIELDS], + ) + + def _process_dict(kwargs, ignore_fields=[]): """Convert a dict of data into correct dtypes/shapes according to key""" # Deal with _some_ dtype issues @@ -141,6 +157,13 @@ def _process_dict(kwargs, ignore_fields=[]): # ^ this tensor is a scalar; we need to give it # a data dimension to play nice with irreps kwargs[k] = v + elif isinstance(v, torch.Tensor): + # This is a tensor, so we just don't do anything except avoid the warning in the `else` + pass + else: + warnings.warn( + f"Value for field {k} was of unsupported type {type(v)} (value was {v})" + ) if AtomicDataDict.BATCH_KEY in kwargs: num_frames = kwargs[AtomicDataDict.BATCH_KEY].max() + 1 @@ -213,7 +236,6 @@ class AtomicData(Data): def __init__( self, irreps: Dict[str, e3nn.o3.Irreps] = {}, _validate: bool = True, **kwargs ): - # empty init needed by get_example if len(kwargs) == 0 and len(irreps) == 0: super().__init__() @@ -404,7 +426,6 @@ def from_ase( ) if atoms.calc is not None: - if isinstance( atoms.calc, (SinglePointCalculator, SinglePointDFTCalculator) ): @@ -680,6 +701,18 @@ def without_nodes(self, which_nodes): return type(self)(**new_dict) +_ERROR_ON_NO_EDGES: bool = os.environ.get("NEQUIP_ERROR_ON_NO_EDGES", "true").lower() +assert _ERROR_ON_NO_EDGES in ("true", "false") +_ERROR_ON_NO_EDGES = _ERROR_ON_NO_EDGES == "true" + +_NEQUIP_MATSCIPY_NL: Final[bool] = os.environ.get("NEQUIP_MATSCIPY_NL", "false").lower() +assert _NEQUIP_MATSCIPY_NL in ("true", "false") +_NEQUIP_MATSCIPY_NL = _NEQUIP_MATSCIPY_NL == "true" + +if _NEQUIP_MATSCIPY_NL: + import matscipy.neighbours + + def neighbor_list_and_relative_vec( pos, r_max, @@ -757,22 +790,32 @@ def neighbor_list_and_relative_vec( # ASE dependent part temp_cell = ase.geometry.complete_cell(temp_cell) - first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list( - "ijS", - pbc, - temp_cell, - temp_pos, - cutoff=float(r_max), - self_interaction=strict_self_interaction, # we want edges from atom to itself in different periodic images! - use_scaled_positions=False, - ) + if _NEQUIP_MATSCIPY_NL: + assert strict_self_interaction and not self_interaction + first_idex, second_idex, shifts = matscipy.neighbours.neighbour_list( + "ijS", + pbc=pbc, + cell=temp_cell, + positions=temp_pos, + cutoff=float(r_max), + ) + else: + first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list( + "ijS", + pbc, + temp_cell, + temp_pos, + cutoff=float(r_max), + self_interaction=strict_self_interaction, # we want edges from atom to itself in different periodic images! + use_scaled_positions=False, + ) # Eliminate true self-edges that don't cross periodic boundaries if not self_interaction: bad_edge = first_idex == second_idex bad_edge &= np.all(shifts == 0, axis=1) keep_edge = ~bad_edge - if not np.any(keep_edge): + if _ERROR_ON_NO_EDGES and (not np.any(keep_edge)): raise ValueError( f"Every single atom has no neighbors within the cutoff r_max={r_max} (after eliminating self edges, no edges remain in this system)" ) diff --git a/nequip/data/AtomicDataDict.py b/nequip/data/AtomicDataDict.py index 069f8cff..f7713e6f 100644 --- a/nequip/data/AtomicDataDict.py +++ b/nequip/data/AtomicDataDict.py @@ -111,4 +111,12 @@ def with_batch(data: Type) -> Type: pos = data[_keys.POSITIONS_KEY] batch = torch.zeros(len(pos), dtype=torch.long, device=pos.device) data[_keys.BATCH_KEY] = batch + # ugly way to make a tensor of [0, len(pos)], but it avoids transfers or casts + data[_keys.BATCH_PTR_KEY] = torch.arange( + start=0, + end=len(pos) + 1, + step=len(pos), + dtype=torch.long, + device=pos.device, + ) return data diff --git a/nequip/data/__init__.py b/nequip/data/__init__.py index 212cc5f6..02c41d55 100644 --- a/nequip/data/__init__.py +++ b/nequip/data/__init__.py @@ -3,12 +3,20 @@ PBC, register_fields, deregister_fields, + _register_field_prefix, _NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS, + _LONG_FIELDS, ) -from .dataset import AtomicDataset, AtomicInMemoryDataset, NpzDataset, ASEDataset -from .dataloader import DataLoader, Collater +from ._dataset import ( + AtomicDataset, + AtomicInMemoryDataset, + NpzDataset, + ASEDataset, + HDF5Dataset, +) +from .dataloader import DataLoader, Collater, PartialSampler from ._build import dataset_from_config from ._test_data import EMTTestDataset @@ -17,15 +25,19 @@ PBC, register_fields, deregister_fields, + _register_field_prefix, AtomicDataset, AtomicInMemoryDataset, NpzDataset, ASEDataset, + HDF5Dataset, DataLoader, Collater, + PartialSampler, dataset_from_config, _NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS, + _LONG_FIELDS, EMTTestDataset, ] diff --git a/nequip/data/_build.py b/nequip/data/_build.py index 8757198f..35b59dba 100644 --- a/nequip/data/_build.py +++ b/nequip/data/_build.py @@ -57,10 +57,10 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset: raise NameError(f"dataset type {dataset_name} does not exists") # if dataset r_max is not found, use the universal r_max - eff_key = "extra_fixed_fields" - prefixed_eff_key = f"{prefix}_{eff_key}" + atomicdata_options_key = "AtomicData_options" + prefixed_eff_key = f"{prefix}_{atomicdata_options_key}" config[prefixed_eff_key] = get_w_prefix( - eff_key, {}, prefix=prefix, arg_dicts=config + atomicdata_options_key, {}, prefix=prefix, arg_dicts=config ) config[prefixed_eff_key]["r_max"] = get_w_prefix( "r_max", diff --git a/nequip/data/_dataset/__init__.py b/nequip/data/_dataset/__init__.py new file mode 100644 index 00000000..9948e377 --- /dev/null +++ b/nequip/data/_dataset/__init__.py @@ -0,0 +1,6 @@ +from ._base_datasets import AtomicDataset, AtomicInMemoryDataset +from ._ase_dataset import ASEDataset +from ._npz_dataset import NpzDataset +from ._hdf5_dataset import HDF5Dataset + +__all__ = [ASEDataset, AtomicDataset, AtomicInMemoryDataset, NpzDataset, HDF5Dataset] diff --git a/nequip/data/_dataset/_ase_dataset.py b/nequip/data/_dataset/_ase_dataset.py new file mode 100644 index 00000000..3246d791 --- /dev/null +++ b/nequip/data/_dataset/_ase_dataset.py @@ -0,0 +1,238 @@ +import tempfile +import functools +import itertools +from os.path import dirname, basename, abspath +from typing import Dict, Any, List, Union, Optional, Sequence + +import ase +import ase.io + +import torch +import torch.multiprocessing as mp + + +from nequip.utils.multiprocessing import num_tasks +from .. import AtomicData +from ..transforms import TypeMapper +from ._base_datasets import AtomicInMemoryDataset + + +def _ase_dataset_reader( + rank: int, + world_size: int, + tmpdir: str, + ase_kwargs: dict, + atomicdata_kwargs: dict, + include_frames, + global_options: dict, +) -> Union[str, List[AtomicData]]: + """Parallel reader for all frames in file.""" + if world_size > 1: + from nequip.utils._global_options import _set_global_options + + # ^ avoid import loop + # we only `multiprocessing` if world_size > 1 + _set_global_options(global_options) + # interleave--- in theory it is better for performance for the ranks + # to read consecutive blocks, but the way ASE is written the whole + # file gets streamed through all ranks anyway, so just trust the OS + # to cache things sanely, which it will. + # ASE handles correctly the case where there are no frames in index + # and just gives an empty list, so that will succeed: + index = slice(rank, None, world_size) + if include_frames is None: + # count includes 0, 1, ..., inf + include_frames = itertools.count() + + datas = [] + # stream them from ase too using iread + for i, atoms in enumerate(ase.io.iread(**ase_kwargs, index=index, parallel=False)): + global_index = rank + (world_size * i) + datas.append( + ( + global_index, + AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs) + if global_index in include_frames + # in-memory dataset will ignore this later, but needed for indexing to work out + else None, + ) + ) + # Save to a tempfile--- + # there can be a _lot_ of tensors here, and rather than dealing with + # the complications of running out of file descriptors and setting + # sharing methods, since this is a one time thing, just make it simple + # and avoid shared memory entirely. + if world_size > 1: + path = f"{tmpdir}/rank{rank}.pth" + torch.save(datas, path) + return path + else: + return datas + + +class ASEDataset(AtomicInMemoryDataset): + """ + + Args: + ase_args (dict): arguments for ase.io.read + include_keys (list): in addition to forces and energy, the keys that needs to + be parsed into dataset + The data stored in ase.atoms.Atoms.array has the lowest priority, + and it will be overrided by data in ase.atoms.Atoms.info + and ase.atoms.Atoms.calc.results. Optional + key_mapping (dict): rename some of the keys to the value str. Optional + + Example: Given an atomic data stored in "H2.extxyz" that looks like below: + + ```H2.extxyz + 2 + Properties=species:S:1:pos:R:3 energy=-10 user_label=2.0 pbc="F F F" + H 0.00000000 0.00000000 0.00000000 + H 0.00000000 0.00000000 1.02000000 + ``` + + The yaml input should be + + ``` + dataset: ase + dataset_file_name: H2.extxyz + ase_args: + format: extxyz + include_keys: + - user_label + key_mapping: + user_label: label0 + chemical_symbols: + - H + ``` + + for VASP parser, the yaml input should be + ``` + dataset: ase + dataset_file_name: OUTCAR + ase_args: + format: vasp-out + key_mapping: + free_energy: total_energy + chemical_symbols: + - H + ``` + + """ + + def __init__( + self, + root: str, + ase_args: dict = {}, + file_name: Optional[str] = None, + url: Optional[str] = None, + AtomicData_options: Dict[str, Any] = {}, + include_frames: Optional[List[int]] = None, + type_mapper: TypeMapper = None, + key_mapping: Optional[dict] = None, + include_keys: Optional[List[str]] = None, + ): + self.ase_args = {} + self.ase_args.update(getattr(type(self), "ASE_ARGS", dict())) + self.ase_args.update(ase_args) + assert "index" not in self.ase_args + assert "filename" not in self.ase_args + + self.include_keys = include_keys + self.key_mapping = key_mapping + + super().__init__( + file_name=file_name, + url=url, + root=root, + AtomicData_options=AtomicData_options, + include_frames=include_frames, + type_mapper=type_mapper, + ) + + @classmethod + def from_atoms_list(cls, atoms: Sequence[ase.Atoms], **kwargs): + """Make an ``ASEDataset`` from a list of ``ase.Atoms`` objects. + + If `root` is not provided, a temporary directory will be used. + + Please note that this is a convinience method that does NOT avoid a round-trip to disk; the provided ``atoms`` will be written out to a file. + + Ignores ``kwargs["file_name"]`` if it is provided. + + Args: + atoms + **kwargs: passed through to the constructor + Returns: + The constructed ``ASEDataset``. + """ + if "root" not in kwargs: + tmpdir = tempfile.TemporaryDirectory() + kwargs["root"] = tmpdir.name + else: + tmpdir = None + kwargs["file_name"] = tmpdir.name + "/atoms.xyz" + atoms = list(atoms) + # Write them out + ase.io.write(kwargs["file_name"], atoms, format="extxyz") + # Read them in + obj = cls(**kwargs) + if tmpdir is not None: + # Make it keep a reference to the tmpdir to keep it alive + # When the dataset is garbage collected, the tmpdir will + # be too, and will (hopefully) get deleted eventually. + # Or at least by end of program... + obj._tmpdir_ref = tmpdir + return obj + + @property + def raw_file_names(self): + return [basename(self.file_name)] + + @property + def raw_dir(self): + return dirname(abspath(self.file_name)) + + def get_data(self): + ase_args = {"filename": self.raw_dir + "/" + self.raw_file_names[0]} + ase_args.update(self.ase_args) + + # skip the None arguments + kwargs = dict( + include_keys=self.include_keys, + key_mapping=self.key_mapping, + ) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + kwargs.update(self.AtomicData_options) + n_proc = num_tasks() + with tempfile.TemporaryDirectory() as tmpdir: + from nequip.utils._global_options import _get_latest_global_options + + # ^ avoid import loop + reader = functools.partial( + _ase_dataset_reader, + world_size=n_proc, + tmpdir=tmpdir, + ase_kwargs=ase_args, + atomicdata_kwargs=kwargs, + include_frames=self.include_frames, + # get the global options of the parent to initialize the worker correctly + global_options=_get_latest_global_options(), + ) + if n_proc > 1: + # things hang for some obscure OpenMP reason on some systems when using `fork` method + ctx = mp.get_context("forkserver") + with ctx.Pool(processes=n_proc) as p: + # map it over the `rank` argument + datas = p.map(reader, list(range(n_proc))) + # clean up the pool before loading the data + datas = [torch.load(d) for d in datas] + datas = sum(datas, []) + # un-interleave the datas + datas = sorted(datas, key=lambda e: e[0]) + else: + datas = reader(rank=0) + # datas here is already in order, stride 1 start 0 + # no need to un-interleave + # return list of AtomicData: + return [e[1] for e in datas] diff --git a/nequip/data/dataset.py b/nequip/data/_dataset/_base_datasets.py similarity index 57% rename from nequip/data/dataset.py rename to nequip/data/_dataset/_base_datasets.py index c38b8eae..bda86734 100644 --- a/nequip/data/dataset.py +++ b/nequip/data/_dataset/_base_datasets.py @@ -1,19 +1,13 @@ import numpy as np import logging -import tempfile import inspect -import functools import itertools import yaml import hashlib -from os.path import dirname, basename, abspath -from typing import Tuple, Dict, Any, List, Callable, Union, Optional, Sequence - -import ase -import ase.io +import math +from typing import Tuple, Dict, Any, List, Callable, Union, Optional import torch -import torch.multiprocessing as mp from torch_runstats.scatter import scatter_std, scatter_mean @@ -31,22 +25,21 @@ from nequip.utils.batch_ops import bincount from nequip.utils.regressor import solver from nequip.utils.savenload import atomic_write -from nequip.utils.multiprocessing import num_tasks -from .transforms import TypeMapper -from .AtomicData import _process_dict +from ..transforms import TypeMapper class AtomicDataset(Dataset): """The base class for all NequIP datasets.""" - fixed_fields: Dict[str, Any] root: str + dtype: torch.dtype def __init__( self, root: str, type_mapper: Optional[TypeMapper] = None, ): + self.dtype = torch.get_default_dtype() super().__init__(root=root, transform=type_mapper) def statistics( @@ -80,7 +73,7 @@ def _get_parameters(self) -> Dict[str, Any]: if k not in IGNORE_KEYS and hasattr(self, k) } # Add other relevant metadata: - params["dtype"] = str(torch.get_default_dtype()) + params["dtype"] = str(self.dtype) params["nequip_version"] = nequip.__version__ return params @@ -117,8 +110,7 @@ class AtomicInMemoryDataset(AtomicDataset): root (str, optional): Root directory where the dataset should be saved. Defaults to current working directory. file_name (str, optional): file name of data source. only used in children class url (str, optional): url to download data source - force_fixed_keys (list, optional): keys to move from AtomicData to fixed_fields dictionary - extra_fixed_fields (dict, optional): extra key that are not stored in data but needed for AtomicData initialization + AtomicData_options (dict, optional): extra key that are not stored in data but needed for AtomicData initialization include_frames (list, optional): the frames to process with the constructor. type_mapper (TypeMapper): the transformation to map atomic information to species index. Optional """ @@ -128,8 +120,7 @@ def __init__( root: str, file_name: Optional[str] = None, url: Optional[str] = None, - force_fixed_keys: List[str] = [], - extra_fixed_fields: Dict[str, Any] = {}, + AtomicData_options: Dict[str, Any] = {}, include_frames: Optional[List[int]] = None, type_mapper: Optional[TypeMapper] = None, ): @@ -138,17 +129,12 @@ def __init__( self.file_name = ( getattr(type(self), "FILE_NAME", None) if file_name is None else file_name ) - force_fixed_keys = set(force_fixed_keys).union( - getattr(type(self), "FORCE_FIXED_KEYS", []) - ) self.url = getattr(type(self), "URL", url) - self.force_fixed_keys = force_fixed_keys - self.extra_fixed_fields = extra_fixed_fields + self.AtomicData_options = AtomicData_options self.include_frames = include_frames self.data = None - self.fixed_fields = None # !!! don't delete this block. # otherwise the inherent children class @@ -165,9 +151,7 @@ def __init__( # Then pre-process the data if disk files are not found super().__init__(root=root, type_mapper=type_mapper) if self.data is None: - self.data, self.fixed_fields, include_frames = torch.load( - self.processed_paths[0] - ) + self.data, include_frames = torch.load(self.processed_paths[0]) if not np.all(include_frames == self.include_frames): raise ValueError( f"the include_frames is changed. " @@ -195,11 +179,9 @@ def get_data( Note that parameters for graph construction such as ``pbc`` and ``r_max`` should be included here as (likely, but not necessarily, fixed) fields. Returns: - A two-tuple of: + A dict: fields: dict mapping a field name ('pos', 'cell') to a list-like sequence of tensor-like objects giving that field's value for each example. - fixed_fields: dict - mapping field names to their constant values for every example in the dataset. Or: data_list: List[AtomicData] """ @@ -216,51 +198,34 @@ def download(self): def process(self): data = self.get_data() - if len(data) == 1: + if isinstance(data, list): # It's a data list - data_list = data[0] - if not (self.include_frames is None or data[0] is None): + data_list = data + if not (self.include_frames is None or data_list is None): data_list = [data_list[i] for i in self.include_frames] assert all(isinstance(e, AtomicData) for e in data_list) assert all(AtomicDataDict.BATCH_KEY not in e for e in data_list) - fields, fixed_fields = {}, {} - - # take the force_fixed_keys away from the fields - for key in self.force_fixed_keys: - if key in data_list[0]: - fixed_fields[key] = data_list[0][key] - - fixed_fields.update(self.extra_fixed_fields) + fields = {} - elif len(data) == 2: - - # It's fields and fixed_fields + elif isinstance(data, dict): + # It's fields # Get our data - fields, fixed_fields = data - - fixed_fields.update(self.extra_fixed_fields) + fields = data # check keys - all_keys = set(fields.keys()).union(fixed_fields.keys()) - assert len(all_keys) == len(fields) + len( - fixed_fields - ), "No overlap in keys between data and fixed fields allowed!" + all_keys = set(fields.keys()) assert AtomicDataDict.BATCH_KEY not in all_keys # Check bad key combinations, but don't require that this be a graph yet. AtomicDataDict.validate_keys(all_keys, graph_required=False) - # take the force_fixed_keys away from the fields - for key in self.force_fixed_keys: - if key in fields: - fixed_fields[key] = fields.pop(key)[0] - - # check dimesionality + # check dimensionality num_examples = set([len(a) for a in fields.values()]) if not len(num_examples) == 1: + shape_dict = {f: v.shape for f, v in fields.items()} raise ValueError( - f"This dataset is invalid: expected all fields to have same length (same number of examples), but they had shapes { {f: v.shape for f, v in fields.items() } }" + f"This dataset is invalid: expected all fields to have same length (same number of examples), but they had shapes {shape_dict}" ) num_examples = next(iter(num_examples)) @@ -275,11 +240,16 @@ def process(self): else: # do neighborlist from points constructor = AtomicData.from_points - assert "r_max" in all_keys + assert "r_max" in self.AtomicData_options assert AtomicDataDict.POSITIONS_KEY in all_keys data_list = [ - constructor(**{**{f: v[i] for f, v in fields.items()}, **fixed_fields}) + constructor( + **{ + **{f: v[i] for f, v in fields.items()}, + **self.AtomicData_options, + } + ) for i in include_frames ] @@ -288,13 +258,10 @@ def process(self): # Batch it for efficient saving # This limits an AtomicInMemoryDataset to a maximum of LONG_MAX atoms _overall_, but that is a very big number and any dataset that large is probably not "InMemory" anyway - data = Batch.from_data_list(data_list, exclude_keys=fixed_fields.keys()) + data = Batch.from_data_list(data_list) del data_list del fields - # type conversion - _process_dict(fixed_fields, ignore_fields=["r_max"]) - total_MBs = sum(item.numel() * item.element_size() for _, item in data) / ( 1024 * 1024 ) @@ -310,21 +277,45 @@ def process(self): # datasets. It only matters that they don't simultaneously try # to write the _same_ file, corrupting it. with atomic_write(self.processed_paths[0], binary=True) as f: - torch.save((data, fixed_fields, self.include_frames), f) + torch.save((data, self.include_frames), f) with atomic_write(self.processed_paths[1], binary=False) as f: yaml.dump(self._get_parameters(), f) logging.info("Cached processed data to disk") self.data = data - self.fixed_fields = fixed_fields def get(self, idx): - out = self.data.get_example(idx) - # Add back fixed fields - for f, v in self.fixed_fields.items(): - out[f] = v - return out + return self.data.get_example(idx) + + def _selectors( + self, + stride: int = 1, + ): + if self._indices is not None: + graph_selector = torch.as_tensor(self._indices)[::stride] + # note that self._indices is _not_ necessarily in order, + # while self.data --- which we take our arrays from --- + # is always in the original order. + # In particular, the values of `self.data.batch` + # are indexes in the ORIGINAL order + # thus we need graph level properties to also be in the original order + # so that batch values index into them correctly + # since self.data.batch is always sorted & contiguous + # (because of Batch.from_data_list) + # we sort it: + graph_selector, _ = torch.sort(graph_selector) + else: + graph_selector = torch.arange(0, self.len(), stride) + + node_selector = torch.as_tensor( + np.in1d(self.data.batch.numpy(), graph_selector.numpy()) + ) + + edge_index = self.data[AtomicDataDict.EDGE_INDEX_KEY] + edge_selector = node_selector[edge_index[0]] & node_selector[edge_index[1]] + + return (graph_selector, node_selector, edge_selector) def statistics( self, @@ -374,45 +365,22 @@ def statistics( if len(fields) == 0: return [] - if self._indices is not None: - graph_selector = torch.as_tensor(self._indices)[::stride] - # note that self._indices is _not_ necessarily in order, - # while self.data --- which we take our arrays from --- - # is always in the original order. - # In particular, the values of `self.data.batch` - # are indexes in the ORIGINAL order - # thus we need graph level properties to also be in the original order - # so that batch values index into them correctly - # since self.data.batch is always sorted & contiguous - # (because of Batch.from_data_list) - # we sort it: - graph_selector, _ = torch.sort(graph_selector) - else: - graph_selector = torch.arange(0, self.len(), stride) - num_graphs = len(graph_selector) + graph_selector, node_selector, edge_selector = self._selectors(stride=stride) - node_selector = torch.as_tensor( - np.in1d(self.data.batch.numpy(), graph_selector.numpy()) - ) + num_graphs = len(graph_selector) num_nodes = node_selector.sum() - - edge_index = self.data[AtomicDataDict.EDGE_INDEX_KEY] - edge_selector = node_selector[edge_index[0]] & node_selector[edge_index[1]] num_edges = edge_selector.sum() - del edge_index if self.transform is not None: - # pre-transform the fixed fields and data so that statistics process transformed data - ff_transformed = self.transform(self.fixed_fields, types_required=False) + # pre-transform the data so that statistics process transformed data data_transformed = self.transform(self.data.to_dict(), types_required=False) else: - ff_transformed = self.fixed_fields data_transformed = self.data.to_dict() # pre-select arrays # this ensures that all following computations use the right data all_keys = set() selectors = {} - for k in list(ff_transformed.keys()) + list(data_transformed.keys()): + for k in data_transformed.keys(): all_keys.add(k) if k in _NODE_FIELDS: selectors[k] = node_selector @@ -425,9 +393,6 @@ def statistics( # TODO: do the batch indexes, edge_indexes, etc. after selection need to be # "compacted" to subtract out their offsets? For now, we just punt this # onto the writer of the callable field. - # do not actually select on fixed fields, since they are constant - # but still only select fields that are correctly registered - ff_transformed = {k: v for k, v in ff_transformed.items() if k in selectors} # apply selector to actual data data_transformed = { k: data_transformed[k][selectors[k]] @@ -441,9 +406,7 @@ def statistics( if callable(field): # make a joined thing? so it includes fixed fields arr, arr_is_per = field(data_transformed) - arr = arr.to( - torch.get_default_dtype() - ) # all statistics must be on floating + arr = arr.to(self.dtype) # all statistics must be on floating assert arr_is_per in ("node", "graph", "edge") else: if field not in all_keys: @@ -455,10 +418,7 @@ def statistics( raise RuntimeError( f"Only per-node and per-graph fields can have statistics computed; `{field}` has not been registered as either. If it is per-node or per-graph, please register it as such using `nequip.data.register_fields`" ) - if field in ff_transformed: - arr = ff_transformed[field] - else: - arr = data_transformed[field] + arr = data_transformed[field] if field in _NODE_FIELDS: arr_is_per = "node" elif field in _GRAPH_FIELDS: @@ -475,7 +435,7 @@ def statistics( ) if not isinstance(arr, torch.Tensor): if np.issubdtype(arr.dtype, np.floating): - arr = torch.as_tensor(arr, dtype=torch.get_default_dtype()) + arr = torch.as_tensor(arr, dtype=self.dtype) else: arr = torch.as_tensor(arr) if arr_is_per == "node": @@ -499,10 +459,17 @@ def statistics( elif ana_mode == "mean_std": # mean and std + if len(arr) < 2: + raise ValueError( + "Can't do per species standard deviation without at least two samples" + ) mean = torch.mean(arr, dim=0) std = torch.std(arr, dim=0, unbiased=unbiased) out.append((mean, std)) + elif ana_mode == "absmax": + out.append((arr.abs().max(),)) + elif ana_mode.startswith("per_species_"): # per-species algorithm_kwargs = kwargs.pop(field + ana_mode, {}) @@ -510,15 +477,7 @@ def statistics( ana_mode = ana_mode[len("per_species_") :] if atom_types is None: - if AtomicDataDict.ATOM_TYPE_KEY in data_transformed: - atom_types = data_transformed[AtomicDataDict.ATOM_TYPE_KEY] - elif AtomicDataDict.ATOM_TYPE_KEY in ff_transformed: - atom_types = ff_transformed[AtomicDataDict.ATOM_TYPE_KEY] - atom_types = ( - atom_types.unsqueeze(0) - .expand((num_graphs,) + atom_types.shape) - .reshape(-1) - ) + atom_types = data_transformed[AtomicDataDict.ATOM_TYPE_KEY] results = self._per_species_statistics( ana_mode, @@ -573,18 +532,24 @@ def _per_atom_statistics( arr = arr / N assert arr.shape == (len(N),) + data_dim if ana_mode == "mean_std": + if len(arr) < 2: + raise ValueError( + "Can't do standard deviation without at least two samples" + ) mean = torch.mean(arr, dim=0) std = torch.std(arr, unbiased=unbiased, dim=0) return mean, std elif ana_mode == "rms": return (torch.sqrt(torch.mean(arr.square())),) + elif ana_mode == "absmax": + return (torch.max(arr.abs()),) else: raise NotImplementedError( f"{ana_mode} for per-atom analysis is not implemented" ) - @staticmethod def _per_species_statistics( + self, ana_mode: str, arr: torch.Tensor, arr_is_per: str, @@ -610,14 +575,20 @@ def _per_species_statistics( f"{ana_mode} for per species analysis is not implemented for shape {arr.shape}" ) - N = N.type(torch.get_default_dtype()) + N = N.type(self.dtype) return solver(N, arr, **algorithm_kwargs) elif arr_is_per == "node": - arr = arr.type(torch.get_default_dtype()) + arr = arr.type(self.dtype) if ana_mode == "mean_std": + # There need to be at least two occurances of each atom type in the + # WHOLE dataset, not in any given frame: + if torch.any(N.sum(dim=0) < 2): + raise ValueError( + "Can't do per species standard deviation without at least two samples per species" + ) mean = scatter_mean(arr, atom_types, dim=0) assert mean.shape[1:] == arr.shape[1:] # [N, dims] -> [type, dims] assert len(mean) == N.shape[1] @@ -632,344 +603,62 @@ def _per_species_statistics( for i in range(dims): square = square.mean(axis=-1) return (torch.sqrt(square),) + else: + raise NotImplementedError( + f"Statistics mode {ana_mode} isn't yet implemented for per_species_" + ) else: raise NotImplementedError + def rdf( + self, bin_width: float, stride: int = 1 + ) -> Dict[Tuple[int, int], Tuple[np.ndarray, np.ndarray]]: + """Compute the pairwise RDFs of the dataset. -class NpzDataset(AtomicInMemoryDataset): - """Load data from an npz file. - - To avoid loading unneeded data, keys are ignored by default unless they are in ``key_mapping``, ``include_keys``, - ``npz_fixed_fields_keys`` or ``extra_fixed_fields``. - - Args: - key_mapping (Dict[str, str]): mapping of npz keys to ``AtomicData`` keys. Optional - include_keys (list): the attributes to be processed and stored. Optional - npz_fixed_field_keys: the attributes that only have one instance but apply to all frames. Optional - Note that the mapped keys (as determined by the _values_ in ``key_mapping``) should be used in - ``npz_fixed_field_keys``, not the original npz keys from before mapping. If an npz key is not - present in ``key_mapping``, it is mapped to itself, and this point is not relevant. - - Example: Given a npz file with 10 configurations, each with 14 atoms. - - position: (10, 14, 3) - force: (10, 14, 3) - energy: (10,) - Z: (14) - user_label1: (10) # per config - user_label2: (10, 14, 3) # per atom - - The input yaml should be - - ```yaml - dataset: npz - dataset_file_name: example.npz - include_keys: - - user_label1 - - user_label2 - npz_fixed_field_keys: - - cell - - atomic_numbers - key_mapping: - position: pos - force: forces - energy: total_energy - Z: atomic_numbers - ``` - - """ - - def __init__( - self, - root: str, - key_mapping: Dict[str, str] = { - "positions": AtomicDataDict.POSITIONS_KEY, - "energy": AtomicDataDict.TOTAL_ENERGY_KEY, - "force": AtomicDataDict.FORCE_KEY, - "forces": AtomicDataDict.FORCE_KEY, - "Z": AtomicDataDict.ATOMIC_NUMBERS_KEY, - "atomic_number": AtomicDataDict.ATOMIC_NUMBERS_KEY, - }, - include_keys: List[str] = [], - npz_fixed_field_keys: List[str] = [], - file_name: Optional[str] = None, - url: Optional[str] = None, - force_fixed_keys: List[str] = [], - extra_fixed_fields: Dict[str, Any] = {}, - include_frames: Optional[List[int]] = None, - type_mapper: TypeMapper = None, - ): - self.key_mapping = key_mapping - self.npz_fixed_field_keys = npz_fixed_field_keys - self.include_keys = include_keys - - super().__init__( - file_name=file_name, - url=url, - root=root, - force_fixed_keys=force_fixed_keys, - extra_fixed_fields=extra_fixed_fields, - include_frames=include_frames, - type_mapper=type_mapper, - ) + Args: + bin_width: width of the histogram bin in distance units + stride: stride of data to include - @property - def raw_file_names(self): - return [basename(self.file_name)] + Returns: + dictionary mapping `(type1, type2)` to tuples of `(hist, bin_edges)` in the style of `np.histogram`. + """ + graph_selector, node_selector, edge_selector = self._selectors(stride=stride) - @property - def raw_dir(self): - return dirname(abspath(self.file_name)) + data = AtomicData.to_AtomicDataDict(self.data) + data = AtomicDataDict.with_edge_vectors(data, with_lengths=True) - def get_data(self): + results = {} - data = np.load(self.raw_dir + "/" + self.raw_file_names[0], allow_pickle=True) + types = self.type_mapper(data)[AtomicDataDict.ATOM_TYPE_KEY] - # only the keys explicitly mentioned in the yaml file will be parsed - keys = set(list(self.key_mapping.keys())) - keys.update(self.npz_fixed_field_keys) - keys.update(self.include_keys) - keys.update(list(self.extra_fixed_fields.keys())) - keys = keys.intersection(set(list(data.keys()))) + edge_types = torch.index_select( + types, 0, data[AtomicDataDict.EDGE_INDEX_KEY].reshape(-1) + ).view(2, -1) + types_center = edge_types[0].numpy() + types_neigh = edge_types[1].numpy() - mapped = {self.key_mapping.get(k, k): data[k] for k in keys} + r_max: float = self.AtomicData_options["r_max"] + # + 1 to always have a zero bin at the end + n_bins: int = int(math.ceil(r_max / bin_width)) + 1 + # +1 since these are bin_edges including rightmost + bins = bin_width * np.arange(n_bins + 1) - # TODO: generalize this? - for intkey in ( - AtomicDataDict.ATOMIC_NUMBERS_KEY, - AtomicDataDict.ATOM_TYPE_KEY, - AtomicDataDict.EDGE_INDEX_KEY, + for type1, type2 in itertools.combinations_with_replacement( + range(self.type_mapper.num_types), 2 ): - if intkey in mapped: - mapped[intkey] = mapped[intkey].astype(np.int64) - - fields = {k: v for k, v in mapped.items() if k not in self.npz_fixed_field_keys} - # note that we don't deal with extra_fixed_fields here; AtomicInMemoryDataset does that. - fixed_fields = { - k: v for k, v in mapped.items() if k in self.npz_fixed_field_keys - } - return fields, fixed_fields - - -def _ase_dataset_reader( - rank: int, - world_size: int, - tmpdir: str, - ase_kwargs: dict, - atomicdata_kwargs: dict, - include_frames, - global_options: dict, -) -> Union[str, List[AtomicData]]: - """Parallel reader for all frames in file.""" - if world_size > 1: - from nequip.utils._global_options import _set_global_options - - # ^ avoid import loop - # we only `multiprocessing` if world_size > 1 - _set_global_options(global_options) - # interleave--- in theory it is better for performance for the ranks - # to read consecutive blocks, but the way ASE is written the whole - # file gets streamed through all ranks anyway, so just trust the OS - # to cache things sanely, which it will. - # ASE handles correctly the case where there are no frames in index - # and just gives an empty list, so that will succeed: - index = slice(rank, None, world_size) - if include_frames is None: - # count includes 0, 1, ..., inf - include_frames = itertools.count() - - datas = [] - # stream them from ase too using iread - for i, atoms in enumerate(ase.io.iread(**ase_kwargs, index=index, parallel=False)): - global_index = rank + (world_size * i) - datas.append( - ( - global_index, - AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs) - if global_index in include_frames - # in-memory dataset will ignore this later, but needed for indexing to work out - else None, + # Try to do as much of this as possible in-place + mask = types_center == type1 + np.logical_and(mask, types_neigh == type2, out=mask) + np.logical_and(mask, edge_selector, out=mask) + mask = mask.astype(np.int32) + results[(type1, type2)] = np.histogram( + data[AtomicDataDict.EDGE_LENGTH_KEY], + weights=mask, + bins=bins, + density=True, ) - ) - # Save to a tempfile--- - # there can be a _lot_ of tensors here, and rather than dealing with - # the complications of running out of file descriptors and setting - # sharing methods, since this is a one time thing, just make it simple - # and avoid shared memory entirely. - if world_size > 1: - path = f"{tmpdir}/rank{rank}.pth" - torch.save(datas, path) - return path - else: - return datas - - -class ASEDataset(AtomicInMemoryDataset): - """ - - Args: - ase_args (dict): arguments for ase.io.read - include_keys (list): in addition to forces and energy, the keys that needs to - be parsed into dataset - The data stored in ase.atoms.Atoms.array has the lowest priority, - and it will be overrided by data in ase.atoms.Atoms.info - and ase.atoms.Atoms.calc.results. Optional - key_mapping (dict): rename some of the keys to the value str. Optional - - Example: Given an atomic data stored in "H2.extxyz" that looks like below: - - ```H2.extxyz - 2 - Properties=species:S:1:pos:R:3 energy=-10 user_label=2.0 pbc="F F F" - H 0.00000000 0.00000000 0.00000000 - H 0.00000000 0.00000000 1.02000000 - ``` - - The yaml input should be - - ``` - dataset: ase - dataset_file_name: H2.extxyz - ase_args: - format: extxyz - include_keys: - - user_label - key_mapping: - user_label: label0 - chemical_symbols: - - H - ``` - - for VASP parser, the yaml input should be - ``` - dataset: ase - dataset_file_name: OUTCAR - ase_args: - format: vasp-out - key_mapping: - free_energy: total_energy - chemical_symbols: - - H - ``` - - """ - - def __init__( - self, - root: str, - ase_args: dict = {}, - file_name: Optional[str] = None, - url: Optional[str] = None, - force_fixed_keys: List[str] = [], - extra_fixed_fields: Dict[str, Any] = {}, - include_frames: Optional[List[int]] = None, - type_mapper: TypeMapper = None, - key_mapping: Optional[dict] = None, - include_keys: Optional[List[str]] = None, - ): - self.ase_args = {} - self.ase_args.update(getattr(type(self), "ASE_ARGS", dict())) - self.ase_args.update(ase_args) - assert "index" not in self.ase_args - assert "filename" not in self.ase_args - - self.include_keys = include_keys - self.key_mapping = key_mapping - - super().__init__( - file_name=file_name, - url=url, - root=root, - force_fixed_keys=force_fixed_keys, - extra_fixed_fields=extra_fixed_fields, - include_frames=include_frames, - type_mapper=type_mapper, - ) - - @classmethod - def from_atoms_list(cls, atoms: Sequence[ase.Atoms], **kwargs): - """Make an ``ASEDataset`` from a list of ``ase.Atoms`` objects. - - If `root` is not provided, a temporary directory will be used. + # RDF is symmetric + results[(type2, type1)] = results[(type1, type2)] - Please note that this is a convinience method that does NOT avoid a round-trip to disk; the provided ``atoms`` will be written out to a file. - - Ignores ``kwargs["file_name"]`` if it is provided. - - Args: - atoms - **kwargs: passed through to the constructor - Returns: - The constructed ``ASEDataset``. - """ - if "root" not in kwargs: - tmpdir = tempfile.TemporaryDirectory() - kwargs["root"] = tmpdir.name - else: - tmpdir = None - kwargs["file_name"] = tmpdir.name + "/atoms.xyz" - atoms = list(atoms) - # Write them out - ase.io.write(kwargs["file_name"], atoms, format="extxyz") - # Read them in - obj = cls(**kwargs) - if tmpdir is not None: - # Make it keep a reference to the tmpdir to keep it alive - # When the dataset is garbage collected, the tmpdir will - # be too, and will (hopefully) get deleted eventually. - # Or at least by end of program... - obj._tmpdir_ref = tmpdir - return obj - - @property - def raw_file_names(self): - return [basename(self.file_name)] - - @property - def raw_dir(self): - return dirname(abspath(self.file_name)) - - def get_data(self): - ase_args = {"filename": self.raw_dir + "/" + self.raw_file_names[0]} - ase_args.update(self.ase_args) - - # skip the None arguments - kwargs = dict( - include_keys=self.include_keys, - key_mapping=self.key_mapping, - ) - kwargs = {k: v for k, v in kwargs.items() if v is not None} - kwargs.update(self.extra_fixed_fields) - n_proc = num_tasks() - with tempfile.TemporaryDirectory() as tmpdir: - from nequip.utils._global_options import _get_latest_global_options - - # ^ avoid import loop - reader = functools.partial( - _ase_dataset_reader, - world_size=n_proc, - tmpdir=tmpdir, - ase_kwargs=ase_args, - atomicdata_kwargs=kwargs, - include_frames=self.include_frames, - # get the global options of the parent to initialize the worker correctly - global_options=_get_latest_global_options(), - ) - if n_proc > 1: - # things hang for some obscure OpenMP reason on some systems when using `fork` method - ctx = mp.get_context("forkserver") - with ctx.Pool(processes=n_proc) as p: - # map it over the `rank` argument - datas = p.map(reader, list(range(n_proc))) - # clean up the pool before loading the data - datas = [torch.load(d) for d in datas] - datas = sum(datas, []) - # un-interleave the datas - datas = sorted(datas, key=lambda e: e[0]) - else: - datas = reader(rank=0) - # datas here is already in order, stride 1 start 0 - # no need to un-interleave - # return list of AtomicData: - return ([e[1] for e in datas],) + return results diff --git a/nequip/data/_dataset/_hdf5_dataset.py b/nequip/data/_dataset/_hdf5_dataset.py new file mode 100644 index 00000000..5fce39e2 --- /dev/null +++ b/nequip/data/_dataset/_hdf5_dataset.py @@ -0,0 +1,171 @@ +from typing import Dict, Any, List, Callable, Union, Optional +from collections import defaultdict +import numpy as np + +import torch + +from .. import ( + AtomicData, + AtomicDataDict, +) +from ..transforms import TypeMapper +from ._base_datasets import AtomicDataset + + +class HDF5Dataset(AtomicDataset): + """A dataset that loads data from a HDF5 file. + + This class is useful for very large datasets that cannot fit in memory. It + efficiently loads data from disk as needed without everything needing to be + in memory at once. + + To use this, ``file_name`` should point to the HDF5 file, or alternatively a + semicolon separated list of multiple files. Each group in the file contains + samples that all have the same number of atoms. Typically there is one + group for each unique number of atoms, but that is not required. Each group + should contain arrays whose length equals the number of samples, one for each + type of data. The names of the arrays can be specified with ``key_mapping``. + + Args: + key_mapping (Dict[str, str]): mapping of array names in the HDF5 file to ``AtomicData`` keys + file_name (string): a semicolon separated list of HDF5 files. + """ + + def __init__( + self, + root: str, + key_mapping: Dict[str, str] = { + "pos": AtomicDataDict.POSITIONS_KEY, + "energy": AtomicDataDict.TOTAL_ENERGY_KEY, + "forces": AtomicDataDict.FORCE_KEY, + "atomic_numbers": AtomicDataDict.ATOMIC_NUMBERS_KEY, + "types": AtomicDataDict.ATOM_TYPE_KEY, + }, + file_name: Optional[str] = None, + AtomicData_options: Dict[str, Any] = {}, + type_mapper: Optional[TypeMapper] = None, + ): + super().__init__(root=root, type_mapper=type_mapper) + self.key_mapping = key_mapping + self.key_list = list(key_mapping.keys()) + self.value_list = list(key_mapping.values()) + self.file_name = file_name + self.r_max = AtomicData_options["r_max"] + self.index = None + self.num_frames = 0 + import h5py + + files = [h5py.File(f, "r") for f in self.file_name.split(";")] + for file in files: + for group_name in file: + for key in self.key_list: + if key in file[group_name]: + self.num_frames += len(file[group_name][key]) + break + file.close() + + def setup_index(self): + import h5py + + files = [h5py.File(f, "r") for f in self.file_name.split(";")] + self.has_forces = False + self.index = [] + for file in files: + for group_name in file: + group = file[group_name] + values = [None] * len(self.key_list) + samples = 0 + for i, key in enumerate(self.key_list): + if key in group: + values[i] = group[key] + samples = len(values[i]) + for i in range(samples): + self.index.append(tuple(values + [i])) + + def len(self) -> int: + return self.num_frames + + def get(self, idx: int) -> AtomicData: + if self.index is None: + self.setup_index() + data = self.index[idx] + i = data[-1] + args = {"r_max": self.r_max} + for j, value in enumerate(self.value_list): + if data[j] is not None: + args[value] = data[j][i] + return AtomicData.from_points(**args) + + def statistics( + self, + fields: List[Union[str, Callable]], + modes: List[str], + stride: int = 1, + unbiased: bool = True, + kwargs: Optional[Dict[str, dict]] = {}, + ) -> List[tuple]: + assert len(modes) == len(fields) + # TODO: use RunningStats + if len(fields) == 0: + return [] + if self.index is None: + self.setup_index() + results = [] + indices = self.indices() + if stride != 1: + indices = list(indices)[::stride] + for field, mode in zip(fields, modes): + count = 0 + if mode == "rms": + total = 0.0 + elif mode in ("mean_std", "per_atom_mean_std"): + total = [0.0, 0.0] + elif mode == "count": + counts = defaultdict(int) + else: + raise NotImplementedError(f"Analysis mode '{mode}' is not implemented") + for index in indices: + data = self.index[index] + i = data[-1] + if field in self.value_list: + values = data[self.value_list.index(field)][i] + elif callable(field): + values, _ = field(self.get(index)) + values = np.asarray(values) + else: + raise RuntimeError( + f"The field key `{field}` is not present in this dataset" + ) + length = len(values.flatten()) + if length == 1: + values = np.array([values]) + if mode == "rms": + total += np.sum(values * values) + count += length + elif mode == "count": + for v in values: + counts[v] += 1 + else: + if mode == "per_atom_mean_std": + values /= len(data[0][i]) + for v in values: + count += 1 + delta1 = v - total[0] + total[0] += delta1 / count + delta2 = v - total[0] + total[1] += delta1 * delta2 + if mode == "rms": + results.append(torch.tensor((np.sqrt(total / count),))) + elif mode == "count": + values = sorted(counts.keys()) + results.append( + (torch.tensor(values), torch.tensor([counts[v] for v in values])) + ) + else: + results.append( + ( + torch.tensor(total[0]), + torch.tensor(np.sqrt(total[1] / (count - 1))), + ) + ) + return results diff --git a/nequip/data/_dataset/_npz_dataset.py b/nequip/data/_dataset/_npz_dataset.py new file mode 100644 index 00000000..3b28daaf --- /dev/null +++ b/nequip/data/_dataset/_npz_dataset.py @@ -0,0 +1,141 @@ +import numpy as np +from os.path import dirname, basename, abspath +from typing import Dict, Any, List, Optional + + +from .. import AtomicDataDict, _LONG_FIELDS, _NODE_FIELDS, _GRAPH_FIELDS +from ..transforms import TypeMapper +from ._base_datasets import AtomicInMemoryDataset + + +class NpzDataset(AtomicInMemoryDataset): + """Load data from an npz file. + + To avoid loading unneeded data, keys are ignored by default unless they are in ``key_mapping``, ``include_keys``, + or ``npz_fixed_fields_keys``. + + Args: + key_mapping (Dict[str, str]): mapping of npz keys to ``AtomicData`` keys. Optional + include_keys (list): the attributes to be processed and stored. Optional + npz_fixed_field_keys: the attributes that only have one instance but apply to all frames. Optional + Note that the mapped keys (as determined by the _values_ in ``key_mapping``) should be used in + ``npz_fixed_field_keys``, not the original npz keys from before mapping. If an npz key is not + present in ``key_mapping``, it is mapped to itself, and this point is not relevant. + + Example: Given a npz file with 10 configurations, each with 14 atoms. + + position: (10, 14, 3) + force: (10, 14, 3) + energy: (10,) + Z: (14) + user_label1: (10) # per config + user_label2: (10, 14, 3) # per atom + + The input yaml should be + + ```yaml + dataset: npz + dataset_file_name: example.npz + include_keys: + - user_label1 + - user_label2 + npz_fixed_field_keys: + - cell + - atomic_numbers + key_mapping: + position: pos + force: forces + energy: total_energy + Z: atomic_numbers + graph_fields: + - user_label1 + node_fields: + - user_label2 + ``` + + """ + + def __init__( + self, + root: str, + key_mapping: Dict[str, str] = { + "positions": AtomicDataDict.POSITIONS_KEY, + "energy": AtomicDataDict.TOTAL_ENERGY_KEY, + "force": AtomicDataDict.FORCE_KEY, + "forces": AtomicDataDict.FORCE_KEY, + "Z": AtomicDataDict.ATOMIC_NUMBERS_KEY, + "atomic_number": AtomicDataDict.ATOMIC_NUMBERS_KEY, + }, + include_keys: List[str] = [], + npz_fixed_field_keys: List[str] = [], + file_name: Optional[str] = None, + url: Optional[str] = None, + AtomicData_options: Dict[str, Any] = {}, + include_frames: Optional[List[int]] = None, + type_mapper: TypeMapper = None, + ): + self.key_mapping = key_mapping + self.npz_fixed_field_keys = npz_fixed_field_keys + self.include_keys = include_keys + + super().__init__( + file_name=file_name, + url=url, + root=root, + AtomicData_options=AtomicData_options, + include_frames=include_frames, + type_mapper=type_mapper, + ) + + @property + def raw_file_names(self): + return [basename(self.file_name)] + + @property + def raw_dir(self): + return dirname(abspath(self.file_name)) + + def get_data(self): + + data = np.load(self.raw_dir + "/" + self.raw_file_names[0], allow_pickle=True) + + # only the keys explicitly mentioned in the yaml file will be parsed + keys = set(list(self.key_mapping.keys())) + keys.update(self.npz_fixed_field_keys) + keys.update(self.include_keys) + keys = keys.intersection(set(list(data.keys()))) + + mapped = {self.key_mapping.get(k, k): data[k] for k in keys} + + for intkey in _LONG_FIELDS: + if intkey in mapped: + mapped[intkey] = mapped[intkey].astype(np.int64) + + fields = {k: v for k, v in mapped.items() if k not in self.npz_fixed_field_keys} + num_examples, num_atoms, n_dim = fields[AtomicDataDict.POSITIONS_KEY].shape + assert n_dim == 3 + + # now we replicate and add the fixed fields: + for fixed_field in self.npz_fixed_field_keys: + orig = mapped[fixed_field] + if fixed_field in _NODE_FIELDS: + assert orig.ndim >= 1 # [n_atom, feature_dims] + assert orig.shape[0] == num_atoms + replicated = np.expand_dims(orig, 0) + replicated = np.tile( + replicated, + (num_examples,) + (1,) * len(replicated.shape[1:]), + ) # [n_example, n_atom, feature_dims] + elif fixed_field in _GRAPH_FIELDS: + # orig is [feature_dims] + replicated = np.expand_dims(orig, 0) + replicated = np.tile( + replicated, + (num_examples,) + (1,) * len(replicated.shape[1:]), + ) # [n_example, feature_dims] + else: + raise KeyError( + f"npz_fixed_field_keys contains `{fixed_field}`, but it isn't registered as a node or graph field" + ) + fields[fixed_field] = replicated + return fields diff --git a/nequip/data/_keys.py b/nequip/data/_keys.py index 54b66ce3..edd04cbe 100644 --- a/nequip/data/_keys.py +++ b/nequip/data/_keys.py @@ -45,6 +45,10 @@ # [n_edge, dim] invariant embedding of the edges EDGE_EMBEDDING_KEY: Final[str] = "edge_embedding" EDGE_FEATURES_KEY: Final[str] = "edge_features" +# [n_edge, 1] invariant of the radial cutoff envelope for each edge, allows reuse of cutoff envelopes +EDGE_CUTOFF_KEY: Final[str] = "edge_cutoff" +# edge energy as in Allegro +EDGE_ENERGY_KEY: Final[str] = "edge_energy" NODE_FEATURES_KEY: Final[str] = "node_features" NODE_ATTRS_KEY: Final[str] = "node_attrs" @@ -57,6 +61,7 @@ VIRIAL_KEY: Final[str] = "virial" ALL_ENERGY_KEYS: Final[List[str]] = [ + EDGE_ENERGY_KEY, PER_ATOM_ENERGY_KEY, TOTAL_ENERGY_KEY, FORCE_KEY, @@ -66,6 +71,7 @@ ] BATCH_KEY: Final[str] = "batch" +BATCH_PTR_KEY: Final[str] = "ptr" # Make a list of allowed keys ALLOWED_KEYS: List[str] = [ diff --git a/nequip/data/_test_data.py b/nequip/data/_test_data.py index e8f4109e..498ba13e 100644 --- a/nequip/data/_test_data.py +++ b/nequip/data/_test_data.py @@ -7,7 +7,7 @@ import ase.build from ase.calculators.emt import EMT -from nequip.data import AtomicInMemoryDataset, AtomicData, AtomicDataDict +from nequip.data import AtomicInMemoryDataset, AtomicData from .transforms import TypeMapper @@ -30,7 +30,7 @@ def __init__( dataset_seed: int = 123456, file_name: Optional[str] = None, url: Optional[str] = None, - extra_fixed_fields: Dict[str, Any] = {}, + AtomicData_options: Dict[str, Any] = {}, include_frames: Optional[List[int]] = None, type_mapper: TypeMapper = None, ): @@ -38,7 +38,7 @@ def __init__( assert element in ("Cu", "Pd", "Au", "Pt", "Al", "Ni", "Ag") self.element = element self.sigma = sigma - self.supercell = supercell + self.supercell = tuple(supercell) self.num_frames = num_frames self.dataset_seed = dataset_seed @@ -46,8 +46,7 @@ def __init__( file_name=file_name, url=url, root=root, - force_fixed_keys=[AtomicDataDict.CELL_KEY, AtomicDataDict.PBC_KEY], - extra_fixed_fields=extra_fixed_fields, + AtomicData_options=AtomicData_options, include_frames=include_frames, type_mapper=type_mapper, ) @@ -78,7 +77,7 @@ def get_data(self): forces=base_atoms.get_forces(), total_energy=base_atoms.get_potential_energy(), stress=base_atoms.get_stress(voigt=False), - **self.extra_fixed_fields + **self.AtomicData_options ) ) - return (datas,) + return datas diff --git a/nequip/data/dataloader.py b/nequip/data/dataloader.py index a6c16670..ea9c7fc9 100644 --- a/nequip/data/dataloader.py +++ b/nequip/data/dataloader.py @@ -1,24 +1,22 @@ -from typing import List +from typing import List, Optional, Iterator import torch +from torch.utils.data import Sampler -from nequip.utils.torch_geometric import Batch, Data +from nequip.utils.torch_geometric import Batch, Data, Dataset class Collater(object): """Collate a list of ``AtomicData``. Args: - fixed_fields: which fields are fixed fields exclude_keys: keys to ignore in the input, not copying to the output """ def __init__( self, - fixed_fields: List[str] = [], exclude_keys: List[str] = [], ): - self.fixed_fields = fixed_fields self._exclude_keys = set(exclude_keys) @classmethod @@ -27,35 +25,14 @@ def for_dataset( dataset, exclude_keys: List[str] = [], ): - """Construct a collater appropriate to ``dataset``. - - All kwargs besides ``fixed_fields`` are passed through to the constructor. - """ + """Construct a collater appropriate to ``dataset``.""" return cls( - fixed_fields=list(getattr(dataset, "fixed_fields", {}).keys()), exclude_keys=exclude_keys, ) def collate(self, batch: List[Data]) -> Batch: """Collate a list of data""" - # For fixed fields, we need to batch those that are per-node or - # per-edge, since they need to be repeated in order to have the same - # number of nodes/edges as the full batch graph. - # For fixed fields that are per-example, however — those with __cat_dim__ - # of None — we can just put one copy over the whole batch graph. - # Figure out which ones those are: - new_dim_fixed = set() - for f in self.fixed_fields: - if batch[0].__cat_dim__(f, None) is None: - new_dim_fixed.add(f) - # TODO: cache ^ and the batched versions of fixed fields for various batch sizes if necessary for performance - out = Batch.from_data_list( - batch, exclude_keys=self._exclude_keys.union(new_dim_fixed) - ) - for f in new_dim_fixed: - if f in self._exclude_keys: - continue - out[f] = batch[0][f] + out = Batch.from_data_list(batch, exclude_keys=self._exclude_keys) return out def __call__(self, batch: List[Data]) -> Batch: @@ -86,3 +63,102 @@ def __init__( collate_fn=Collater.for_dataset(dataset, exclude_keys=exclude_keys), **kwargs, ) + + +class PartialSampler(Sampler[int]): + r"""Samples elements without replacement, but divided across a number of calls to `__iter__`. + + To ensure deterministic reproducibility and restartability, dataset permutations are generated + from a combination of the overall seed and the epoch number. As a result, the caller must + tell this sampler the epoch number before each time `__iter__` is called by calling + `my_partial_sampler.step_epoch(epoch_number_about_to_run)` each time. + + This sampler decouples epochs from the dataset size and cycles through the dataset over as + many (partial) epochs as it may take. As a result, the _dataset_ epoch can change partway + through a training epoch. + + Args: + data_source (Dataset): dataset to sample from + shuffle (bool): whether to shuffle the dataset each time the _entire_ dataset is consumed + num_samples_per_epoch (int): number of samples to draw in each call to `__iter__`. + If `None`, defaults to `len(data_source)`. + generator (Generator): Generator used in sampling. + """ + data_source: Dataset + num_samples_per_epoch: int + shuffle: bool + _epoch: int + _prev_epoch: int + + def __init__( + self, + data_source: Dataset, + shuffle: bool = True, + num_samples_per_epoch: Optional[int] = None, + generator=None, + ) -> None: + self.data_source = data_source + self.shuffle = shuffle + if num_samples_per_epoch is None: + num_samples_per_epoch = self.num_samples_total + self.num_samples_per_epoch = num_samples_per_epoch + assert self.num_samples_per_epoch <= self.num_samples_total + assert self.num_samples_per_epoch >= 1 + self.generator = generator + self._epoch = None + self._prev_epoch = None + + @property + def num_samples_total(self) -> int: + # dataset size might change at runtime + return len(self.data_source) + + def step_epoch(self, epoch: int) -> None: + self._epoch = epoch + + def __iter__(self) -> Iterator[int]: + assert self._epoch is not None + assert (self._prev_epoch is None) or (self._epoch == self._prev_epoch + 1) + assert self._epoch >= 0 + + full_epoch_i, start_sample_i = divmod( + # how much data we've already consumed: + self._epoch * self.num_samples_per_epoch, + # how much data there is the dataset: + self.num_samples_total, + ) + + if self.shuffle: + temp_rng = torch.Generator() + # Get new randomness for each _full_ time through the dataset + # This is deterministic w.r.t. the combination of dataset seed and epoch number + # Both of which persist across restarts + # (initial_seed() is restored by set_state()) + temp_rng.manual_seed(self.generator.initial_seed() + full_epoch_i) + full_order_this = torch.randperm(self.num_samples_total, generator=temp_rng) + # reseed the generator for the _next_ epoch to get the shuffled order of the + # _next_ dataset epoch to pad out this one for completing any partial batches + # at the end: + temp_rng.manual_seed(self.generator.initial_seed() + full_epoch_i + 1) + full_order_next = torch.randperm(self.num_samples_total, generator=temp_rng) + del temp_rng + else: + full_order_this = torch.arange(self.num_samples_total) + # without shuffling, the next epoch has the same sampling order as this one: + full_order_next = full_order_this + + full_order = torch.cat((full_order_this, full_order_next), dim=0) + del full_order_next, full_order_this + + this_segment_indexes = full_order[ + start_sample_i : start_sample_i + self.num_samples_per_epoch + ] + # because we cycle into indexes from the next dataset epoch, + # we should _always_ be able to get num_samples_per_epoch + assert len(this_segment_indexes) == self.num_samples_per_epoch + yield from this_segment_indexes + + self._prev_epoch = self._epoch + + def __len__(self) -> int: + return self.num_samples_per_epoch diff --git a/nequip/data/transforms.py b/nequip/data/transforms.py index 4f6331b7..fc4afe51 100644 --- a/nequip/data/transforms.py +++ b/nequip/data/transforms.py @@ -13,6 +13,7 @@ class TypeMapper: num_types: int chemical_symbol_to_type: Optional[Dict[str, int]] + type_to_chemical_symbol: Optional[Dict[int, str]] type_names: List[str] _min_Z: int @@ -20,6 +21,7 @@ def __init__( self, type_names: Optional[List[str]] = None, chemical_symbol_to_type: Optional[Dict[str, int]] = None, + type_to_chemical_symbol: Optional[Dict[int, str]] = None, chemical_symbols: Optional[List[str]] = None, ): if chemical_symbols is not None: @@ -37,6 +39,14 @@ def __init__( chemical_symbol_to_type = {k: i for i, k in enumerate(chemical_symbols)} del chemical_symbols + if type_to_chemical_symbol is not None: + type_to_chemical_symbol = { + int(k): v for k, v in type_to_chemical_symbol.items() + } + assert all( + v in ase.data.chemical_symbols for v in type_to_chemical_symbol.values() + ) + # Build from chem->type mapping, if provided self.chemical_symbol_to_type = chemical_symbol_to_type if self.chemical_symbol_to_type is not None: @@ -75,6 +85,14 @@ def __init__( for sym, type_idx in self.chemical_symbol_to_type.items(): self._index_to_Z[type_idx] = ase.data.atomic_numbers[sym] self._valid_set = set(valid_atomic_numbers) + true_type_to_chemical_symbol = { + type_id: sym for sym, type_id in self.chemical_symbol_to_type.items() + } + if type_to_chemical_symbol is not None: + assert type_to_chemical_symbol == true_type_to_chemical_symbol + else: + type_to_chemical_symbol = true_type_to_chemical_symbol + # check if type_names is None: raise ValueError( @@ -88,6 +106,9 @@ def __init__( self.num_types = len(type_names) # Check type_names self.type_names = type_names + self.type_to_chemical_symbol = type_to_chemical_symbol + if self.type_to_chemical_symbol is not None: + assert set(type_to_chemical_symbol.keys()) == set(range(self.num_types)) def __call__( self, data: Union[AtomicDataDict.Type, AtomicData], types_required: bool = True diff --git a/nequip/model/__init__.py b/nequip/model/__init__.py index b79a820c..30031146 100644 --- a/nequip/model/__init__.py +++ b/nequip/model/__init__.py @@ -6,6 +6,8 @@ initialize_from_state, load_model_state, ) +from ._gmm import GaussianMixtureModelUncertainty +from ._pair_potential import PairPotential, PairPotentialTerm from ._build import model_from_config @@ -22,6 +24,9 @@ uniform_initialize_FCs, initialize_from_state, load_model_state, + GaussianMixtureModelUncertainty, model_from_config, + PairPotential, + PairPotentialTerm, builder_utils, ] diff --git a/nequip/model/_build.py b/nequip/model/_build.py index 7e1a63fd..372ea90b 100644 --- a/nequip/model/_build.py +++ b/nequip/model/_build.py @@ -1,14 +1,23 @@ import inspect from typing import Optional +import torch + from nequip.data import AtomicDataset from nequip.data.transforms import TypeMapper -from nequip.nn import GraphModuleMixin -from nequip.utils import load_callable, instantiate +from nequip.nn import GraphModuleMixin, GraphModel +from nequip.utils import ( + load_callable, + instantiate, + dtype_from_name, + torch_default_dtype, + Config, +) +from nequip.utils.config import _GLOBAL_ALL_ASKED_FOR_KEYS def model_from_config( - config, + config: Config, initialize: bool = False, dataset: Optional[AtomicDataset] = None, deploy: bool = False, @@ -22,6 +31,8 @@ def model_from_config( - ``dataset``: if ``initialize`` is True, the dataset - ``deploy``: whether the model object is for deployment / inference + Note that this function temporarily sets ``torch.set_default_dtype()`` and as such is not thread safe. + Args: config initialize (bool): whether ``model_builders`` should be instructed to initialize the model @@ -31,6 +42,8 @@ def model_from_config( Returns: The build model. """ + if isinstance(config, dict): + config = Config.from_dict(config) # Pre-process config type_mapper = None if dataset is not None: @@ -52,51 +65,121 @@ def model_from_config( ), "inconsistant config & dataset" config["num_types"] = type_mapper.num_types config["type_names"] = type_mapper.type_names + config["type_to_chemical_symbol"] = type_mapper.type_to_chemical_symbol + # We added them, so they are by definition valid: + _GLOBAL_ALL_ASKED_FOR_KEYS.update( + {"num_types", "type_names", "type_to_chemical_symbol"} + ) + + default_dtype = torch.get_default_dtype() + model_dtype: torch.dtype = dtype_from_name(config.get("model_dtype", default_dtype)) + config["model_dtype"] = str(model_dtype).lstrip("torch.") + # confirm sanity + assert default_dtype in (torch.float32, torch.float64) + if default_dtype == torch.float32 and model_dtype == torch.float64: + raise ValueError( + "Overall default_dtype=float32, but model_dtype=float64 is a higher precision- change default_dtype to float64" + ) + # temporarily set the default dtype + start_graph_model_builders = None + with torch_default_dtype(model_dtype): + + # Build + builders = [ + load_callable(b, prefix="nequip.model") + for b in config.get("model_builders", []) + ] - # Build - builders = [ - load_callable(b, prefix="nequip.model") - for b in config.get("model_builders", []) - ] - - model = None - - for builder in builders: - pnames = inspect.signature(builder).parameters - params = {} - if "initialize" in pnames: - params["initialize"] = initialize - if "deploy" in pnames: - params["deploy"] = deploy - if "config" in pnames: - params["config"] = config - if "dataset" in pnames: - if "initialize" not in pnames: - raise ValueError("Cannot request dataset without requesting initialize") - if ( - initialize - and pnames["dataset"].default == inspect.Parameter.empty - and dataset is None - ): - raise RuntimeError( - f"Builder {builder.__name__} requires the dataset, initialize is true, but no dataset was provided to `model_from_config`." + model = None + + for builder_i, builder in enumerate(builders): + pnames = inspect.signature(builder).parameters + params = {} + if "graph_model" in pnames: + # start graph_model builders, which happen later + start_graph_model_builders = builder_i + break + if "initialize" in pnames: + params["initialize"] = initialize + if "deploy" in pnames: + params["deploy"] = deploy + if "config" in pnames: + params["config"] = config + if "dataset" in pnames: + if "initialize" not in pnames: + raise ValueError( + "Cannot request dataset without requesting initialize" + ) + if ( + initialize + and pnames["dataset"].default == inspect.Parameter.empty + and dataset is None + ): + raise RuntimeError( + f"Builder {builder.__name__} requires the dataset, initialize is true, but no dataset was provided to `model_from_config`." + ) + params["dataset"] = dataset + if "model" in pnames: + if model is None: + raise RuntimeError( + f"Builder {builder.__name__} asked for the model as an input, but no previous builder has returned a model" + ) + params["model"] = model + else: + if model is not None: + raise RuntimeError( + f"All model_builders after the first one that returns a model must take the model as an argument; {builder.__name__} doesn't" + ) + model = builder(**params) + if model is not None and not isinstance(model, GraphModuleMixin): + raise TypeError( + f"Builder {builder.__name__} didn't return a GraphModuleMixin, got {type(model)} instead" ) - params["dataset"] = dataset - if "model" in pnames: - if model is None: - raise RuntimeError( - f"Builder {builder.__name__} asked for the model as an input, but no previous builder has returned a model" + # reset to default dtype by context manager + + # Wrap the model up + model = GraphModel( + model, + model_dtype=model_dtype, + model_input_fields=config.get("model_input_fields", {}), + ) + + # Run GraphModel builders + if start_graph_model_builders is not None: + for builder in builders[start_graph_model_builders:]: + pnames = inspect.signature(builder).parameters + params = {} + assert "graph_model" in pnames + params["graph_model"] = model + if "model" in pnames: + raise ValueError( + f"Once any builder requests `graph_model` (first requested by {builders[start_graph_model_builders].__name__}), no builder can request `model`, but {builder.__name__} did" ) - params["model"] = model - else: - if model is not None: - raise RuntimeError( - f"All model_builders after the first one that returns a model must take the model as an argument; {builder.__name__} doesn't" + if "initialize" in pnames: + params["initialize"] = initialize + if "deploy" in pnames: + params["deploy"] = deploy + if "config" in pnames: + params["config"] = config + if "dataset" in pnames: + if "initialize" not in pnames: + raise ValueError( + "Cannot request dataset without requesting initialize" + ) + if ( + initialize + and pnames["dataset"].default == inspect.Parameter.empty + and dataset is None + ): + raise RuntimeError( + f"Builder {builder.__name__} requires the dataset, initialize is true, but no dataset was provided to `model_from_config`." + ) + params["dataset"] = dataset + + model = builder(**params) + if not isinstance(model, GraphModel): + raise TypeError( + f"Builder {builder.__name__} didn't return a GraphModel, got {type(model)} instead" ) - model = builder(**params) - if model is not None and not isinstance(model, GraphModuleMixin): - raise TypeError( - f"Builder {builder.__name__} didn't return a GraphModuleMixin, got {type(model)} instead" - ) return model diff --git a/nequip/model/_gmm.py b/nequip/model/_gmm.py new file mode 100644 index 00000000..196ab360 --- /dev/null +++ b/nequip/model/_gmm.py @@ -0,0 +1,96 @@ +from typing import Optional + +from tqdm.auto import tqdm + +import torch + +from nequip.nn import GraphModel, SequentialGraphNetwork +from nequip.nn import ( + GaussianMixtureModelUncertainty as GaussianMixtureModelUncertaintyModule, +) +from nequip.data import AtomicDataDict, AtomicData, AtomicDataset, Collater +from nequip.utils import find_first_of_type + + +def GaussianMixtureModelUncertainty( + graph_model: GraphModel, + config, + deploy: bool, + initialize: bool, + dataset: Optional[AtomicDataset] = None, + feature_field: str = AtomicDataDict.NODE_FEATURES_KEY, + out_field: Optional[str] = None, +): + r"""Use a GMM on some latent features to predict an uncertainty. + + Only for deployment time! See `configs/minimal_gmm.yaml`. + """ + # it only makes sense to add or fit a GMM to a deployment model whose features are already trained + if (not deploy) or initialize: + raise RuntimeError( + "GaussianMixtureModelUncertainty can only be used at deployment time, see `configs/minimal_gmm.yaml`." + ) + + # = add GMM = + if out_field is None: + out_field = feature_field + "_nll" + + # TODO: this is VERY brittle!!!! + seqnn: SequentialGraphNetwork = find_first_of_type( + graph_model, SequentialGraphNetwork + ) + + gmm: GaussianMixtureModelUncertaintyModule = seqnn.append_from_parameters( + builder=GaussianMixtureModelUncertaintyModule, + name=feature_field + "_gmm", + shared_params=config, + params=dict(feature_field=feature_field, out_field=out_field), + ) + + if dataset is None: + raise RuntimeError( + "GaussianMixtureModelUncertainty requires a dataset to fit the GMM on; did you specify `nequip-deploy --using-dataset`?" + ) + + # = evaluate features = + # set up model + prev_training: bool = graph_model.training + prev_device: torch.device = graph_model.get_device() + device = config.get("device", None) + graph_model.eval() + graph_model.to(device=device) + # evaluate + features = [] + collater = Collater.for_dataset(dataset=dataset) + batch_size: int = config.get("validation_batch_size", config.batch_size) + stride: int = config.get("dataset_statistics_stride", 1) + # TODO: guard TQDM on interactive? + for batch_start_i in tqdm( + range(0, len(dataset), stride * batch_size), + desc="GMM eval features on train set", + ): + batch = collater( + [dataset[batch_start_i + i * stride] for i in range(batch_size)] + ) + # TODO: !! assumption that final value of feature_field is what the + # GMM gets is very brittle, should really be extracting it + # from the GMM module somehow... not sure how that works. + # give it a training mode and exfiltrate it through a buffer? + # it is correct, however, for NequIP and Allegro energy models + features.append( + graph_model(AtomicData.to_AtomicDataDict(batch.to(device=device)))[ + feature_field + ] + .detach() + .to("cpu") # offload to not run out of GPU RAM + ) + features = torch.cat(features, dim=0) + assert features.ndim == 2 + # restore model + graph_model.train(mode=prev_training) + graph_model.to(device=prev_device) + # fit GMM + gmm.fit(features, seed=config["seed"]) + del features + + return graph_model diff --git a/nequip/model/_pair_potential.py b/nequip/model/_pair_potential.py new file mode 100644 index 00000000..25faf84a --- /dev/null +++ b/nequip/model/_pair_potential.py @@ -0,0 +1,39 @@ +from nequip.nn import SequentialGraphNetwork, AtomwiseReduce +from nequip.nn.embedding import AddRadialCutoffToData +from nequip.data import AtomicDataDict +from nequip.nn.pair_potential import SimpleLennardJones, LennardJones, ZBL + +_PAIR_STYLES = {"LJ": SimpleLennardJones, "LJ_fancy": LennardJones, "ZBL": ZBL} + + +def PairPotentialTerm( + model: SequentialGraphNetwork, + config, +) -> SequentialGraphNetwork: + assert isinstance(model, SequentialGraphNetwork) + + model.insert_from_parameters( + shared_params=config, + name="pair_potential", + builder=_PAIR_STYLES[config.pair_style], + before="total_energy_sum", + ) + return model + + +def PairPotential(config) -> SequentialGraphNetwork: + return SequentialGraphNetwork.from_parameters( + shared_params=config, + layers={ + "cutoff": AddRadialCutoffToData, + "pair_potential": _PAIR_STYLES[config.pair_style], + "total_energy_sum": ( + AtomwiseReduce, + dict( + reduce="sum", + field=AtomicDataDict.PER_ATOM_ENERGY_KEY, + out_field=AtomicDataDict.TOTAL_ENERGY_KEY, + ), + ), + }, + ) diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py index 8a7ffa46..d1faaa88 100644 --- a/nequip/model/_scaling.py +++ b/nequip/model/_scaling.py @@ -23,14 +23,14 @@ def RescaleEnergyEtc( dataset=dataset, initialize=initialize, module_prefix="global_rescale", - default_scale=f"dataset_{AtomicDataDict.FORCE_KEY}_rms" - if AtomicDataDict.FORCE_KEY in model.irreps_out - else f"dataset_{AtomicDataDict.TOTAL_ENERGY_KEY}_std", + default_scale=( + f"dataset_{AtomicDataDict.FORCE_KEY}_rms" + if AtomicDataDict.FORCE_KEY in model.irreps_out + else f"dataset_{AtomicDataDict.TOTAL_ENERGY_KEY}_std" + ), default_shift=None, default_scale_keys=AtomicDataDict.ALL_ENERGY_KEYS, default_shift_keys=[AtomicDataDict.TOTAL_ENERGY_KEY], - default_related_scale_keys=[AtomicDataDict.PER_ATOM_ENERGY_KEY], - default_related_shift_keys=[], ) @@ -43,8 +43,6 @@ def GlobalRescale( default_shift: Union[str, float, list], default_scale_keys: list, default_shift_keys: list, - default_related_scale_keys: list, - default_related_shift_keys: list, dataset: Optional[AtomicDataset] = None, ): """Add global rescaling for energy(-based quantities). @@ -113,8 +111,6 @@ def GlobalRescale( error_string = "keys need to be a list" assert isinstance(default_scale_keys, list), error_string assert isinstance(default_shift_keys, list), error_string - assert isinstance(default_related_scale_keys, list), error_string - assert isinstance(default_related_shift_keys, list), error_string # == Build the model == return RescaleOutput( @@ -123,10 +119,9 @@ def GlobalRescale( scale_by=global_scale, shift_keys=[k for k in default_shift_keys if k in model.irreps_out], shift_by=global_shift, - related_scale_keys=default_related_scale_keys, - related_shift_keys=default_related_shift_keys, shift_trainable=config.get(f"{module_prefix}_shift_trainable", False), scale_trainable=config.get(f"{module_prefix}_scale_trainable", False), + default_dtype=config.get("default_dtype", None), ) @@ -136,42 +131,60 @@ def PerSpeciesRescale( initialize: bool, dataset: Optional[AtomicDataset] = None, ): - """Add global rescaling for energy(-based quantities). - - If ``initialize`` is false, doesn't compute statistics. - """ + """Add per-atom rescaling (and shifting) for per-atom energies.""" module_prefix = "per_species_rescale" - # = Determine energy rescale type = - scales = config.get( - module_prefix + "_scales", - f"dataset_{AtomicDataDict.FORCE_KEY}_rms" - # if `train_on_keys` isn't provided, assume conservatively - # that we aren't "training" on anything (i.e. take the - # most general defaults) - if AtomicDataDict.FORCE_KEY in config.get("train_on_keys", []) - else f"dataset_per_atom_{AtomicDataDict.TOTAL_ENERGY_KEY}_std", - ) - shifts = config.get( - module_prefix + "_shifts", - f"dataset_per_atom_{AtomicDataDict.TOTAL_ENERGY_KEY}_mean", - ) - # Check for common double shift mistake with defaults if "RescaleEnergyEtc" in config.get("model_builders", []): # if the defaults are enabled, then we will get bad double shift # THIS CHECK IS ONLY GOOD ENOUGH FOR EMITTING WARNINGS has_global_shift = config.get("global_rescale_shift", None) is not None if has_global_shift: - if shifts is not None: + if config.get(module_prefix + "_shifts", True) is not None: # using default of per_atom shift raise RuntimeError( "A global_rescale_shift was provided, but the default per-atom energy shift was not disabled." ) del has_global_shift - # = Determine what statistics need to be compute =\ - arguments_in_dataset_units = None + return _PerSpeciesRescale( + scales_default=None, + shifts_default=f"dataset_per_atom_{AtomicDataDict.TOTAL_ENERGY_KEY}_mean", + field=AtomicDataDict.PER_ATOM_ENERGY_KEY, + out_field=AtomicDataDict.PER_ATOM_ENERGY_KEY, + module_prefix=module_prefix, + insert_before="total_energy_sum", + model=model, + config=config, + initialize=initialize, + dataset=dataset, + ) + + +def _PerSpeciesRescale( + scales_default, + shifts_default, + field: str, + out_field: str, + module_prefix: str, + insert_before: str, + model: GraphModuleMixin, + config, + initialize: bool, + dataset: Optional[AtomicDataset] = None, +): + """Add per-atom rescaling (and shifting) for a field + + If ``initialize`` is false, doesn't compute statistics. + """ + scales = config.get(module_prefix + "_scales", scales_default) + shifts = config.get(module_prefix + "_shifts", shifts_default) + + # = Determine what statistics need to be compute = + assert config.get( + module_prefix + "_arguments_in_dataset_units", True + ), f"The PerSpeciesRescale builder is only compatible with {module_prefix + '_arguments_in_dataset_units'} set to True" + if initialize: str_names = [] for value in [scales, shifts]: @@ -188,20 +201,6 @@ def PerSpeciesRescale( else: raise ValueError(f"Invalid value `{value}` of type {type(value)}") - if len(str_names) == 2: - # Both computed from dataset - arguments_in_dataset_units = True - elif len(str_names) == 1: - if None in [scales, shifts]: - # if the one that isnt str is null, it's just disabled - # that has no units - # so it's ok to have just one and to be in dataset units - arguments_in_dataset_units = True - else: - assert config[ - module_prefix + "_arguments_in_dataset_units" - ], "Requested to set either the shifts or scales of the per_species_rescale using dataset values, but chose to provide the other in non-dataset units. Please give the explictly specified shifts/scales in dataset units and set per_species_rescale_arguments_in_dataset_units" - # = Compute shifts and scales = if len(str_names) > 0: computed_stats = _compute_stats( @@ -213,21 +212,24 @@ def PerSpeciesRescale( if isinstance(scales, str): s = scales - scales = computed_stats[str_names.index(scales)].squeeze(-1) # energy is 1D + # energy or other property is 1D: + scales = computed_stats[str_names.index(scales)].squeeze(-1) logging.info(f"Replace string {s} to {scales}") elif isinstance(scales, (list, float)): scales = torch.as_tensor(scales) if isinstance(shifts, str): s = shifts - shifts = computed_stats[str_names.index(shifts)].squeeze(-1) # energy is 1D + # energy or other property is 1D: + shifts = computed_stats[str_names.index(shifts)].squeeze(-1) logging.info(f"Replace string {s} to {shifts}") elif isinstance(shifts, (list, float)): shifts = torch.as_tensor(shifts) + # TODO kind of weird error to check for here if scales is not None and torch.min(scales) < RESCALE_THRESHOLD: raise ValueError( - f"Per species energy scaling was very low: {scales}. Maybe try setting {module_prefix}_scales = 1." + f"Per species scaling was very low: {scales}. Maybe try setting {module_prefix}_scales = 1." ) logging.info( @@ -241,22 +243,20 @@ def PerSpeciesRescale( # so this is fine regardless of whether its trainable. scales = 1.0 if scales is not None else None shifts = 0.0 if shifts is not None else None - # values correctly scaled according to where the come from - # will be brought from the state dict later, - # so what you set this to doesnt matter: - arguments_in_dataset_units = False + # values from the previously initialized model + # will be brought in from the state dict later, + # so these values (and rescaling them) doesn't matter # insert in per species shift params = dict( - field=AtomicDataDict.PER_ATOM_ENERGY_KEY, - out_field=AtomicDataDict.PER_ATOM_ENERGY_KEY, + field=field, + out_field=out_field, shifts=shifts, scales=scales, + arguments_in_dataset_units=True, ) - - params["arguments_in_dataset_units"] = arguments_in_dataset_units model.insert_from_parameters( - before="total_energy_sum", + before=insert_before, name=module_prefix, shared_params=config, builder=PerSpeciesScaleShift, @@ -288,7 +288,7 @@ def _compute_stats( stat_strs = [] ids = [] tuple_ids = [] - tuple_id_map = {"mean": 0, "std": 1, "rms": 0} + tuple_id_map = {"mean": 0, "std": 1, "rms": 0, "absmax": 0} input_kwargs = {} for name in str_names: @@ -309,9 +309,9 @@ def _compute_stats( if stat in ["mean", "std"]: stat_mode = prefix + "mean_std" stat_str = field + prefix + "mean_std" - elif stat in ["rms"]: - stat_mode = prefix + "rms" - stat_str = field + prefix + "rms" + elif stat in ["rms", "absmax"]: + stat_mode = prefix + stat + stat_str = field + prefix + stat else: raise ValueError(f"Cannot handle {stat} type quantity") diff --git a/nequip/model/_weight_init.py b/nequip/model/_weight_init.py index 7d6184c4..31d9624d 100644 --- a/nequip/model/_weight_init.py +++ b/nequip/model/_weight_init.py @@ -1,16 +1,19 @@ import math +import logging import torch import e3nn.o3 import e3nn.nn -from nequip.nn import GraphModuleMixin +from nequip.nn import GraphModuleMixin, GraphModel from nequip.utils import Config # == Load old state == -def initialize_from_state(config: Config, model: GraphModuleMixin, initialize: bool): +def initialize_from_state( + config: Config, graph_model: GraphModel, initialize: bool +) -> GraphModel: """Initialize the model from the state dict file given by the config options `initial_model_state`. Only loads the state dict if `initialize` is `True`; this is meant for, say, starting a training from a previous state. @@ -22,18 +25,21 @@ def initialize_from_state(config: Config, model: GraphModuleMixin, initialize: b See https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict. """ if not initialize: - return model # do nothing + return graph_model # do nothing return load_model_state( - config=config, model=model, initialize=initialize, _prefix="initial_model_state" + config=config, + graph_model=graph_model, + initialize=initialize, + _prefix="initial_model_state", ) def load_model_state( config: Config, - model: GraphModuleMixin, + graph_model: GraphModel, initialize: bool, _prefix: str = "load_model_state", -): +) -> GraphModel: """Load the model from the state dict file given by the config options `load_model_state`. Loads the state dict always; this is meant, for example, for building a new model to deploy with a given state dict. @@ -48,9 +54,16 @@ def load_model_state( raise KeyError( f"initialize_from_state requires the `{_prefix}` option specifying the state to initialize from" ) - state = torch.load(config[_prefix]) - model.load_state_dict(state, strict=config.get(_prefix + "_strict", True)) - return model + # Make sure we map to CPU if there is no GPU, otherwise just leave it alone + state = torch.load( + config[_prefix], map_location=None if torch.cuda.is_available() else "cpu" + ) + strict: bool = config.get(_prefix + "_strict", True) + graph_model.load_state_dict(state, strict=strict) + logging.info( + f"Loaded model state {'' if strict else ' with strict=False'} (parameters/weights/persistent buffers) from state {_prefix}={config[_prefix]}" + ) + return graph_model # == Init functions == diff --git a/nequip/nn/__init__.py b/nequip/nn/__init__.py index 10cebee6..6585e698 100644 --- a/nequip/nn/__init__.py +++ b/nequip/nn/__init__.py @@ -1,13 +1,34 @@ -from ._graph_mixin import GraphModuleMixin, SequentialGraphNetwork # noqa: F401 -from ._atomwise import ( # noqa: F401 +from ._graph_mixin import GraphModuleMixin, SequentialGraphNetwork +from ._graph_model import GraphModel +from ._atomwise import ( AtomwiseOperation, AtomwiseReduce, AtomwiseLinear, PerSpeciesScaleShift, -) # noqa: F401 -from ._interaction_block import InteractionBlock # noqa: F401 -from ._grad_output import GradientOutput, PartialForceOutput, StressOutput # noqa: F401 -from ._rescale import RescaleOutput # noqa: F401 -from ._convnetlayer import ConvNetLayer # noqa: F401 -from ._util import SaveForOutput # noqa: F401 -from ._concat import Concat # noqa: F401 +) +from ._interaction_block import InteractionBlock +from ._grad_output import GradientOutput, PartialForceOutput, StressOutput +from ._rescale import RescaleOutput +from ._convnetlayer import ConvNetLayer +from ._util import SaveForOutput +from ._concat import Concat +from ._gmm import GaussianMixtureModelUncertainty + +__all__ = [ + GraphModel, + GraphModuleMixin, + SequentialGraphNetwork, + AtomwiseOperation, + AtomwiseReduce, + AtomwiseLinear, + PerSpeciesScaleShift, + InteractionBlock, + GradientOutput, + PartialForceOutput, + StressOutput, + RescaleOutput, + ConvNetLayer, + SaveForOutput, + Concat, + GaussianMixtureModelUncertainty, +] diff --git a/nequip/nn/_atomwise.py b/nequip/nn/_atomwise.py index 6b7a2ecd..b4020ec5 100644 --- a/nequip/nn/_atomwise.py +++ b/nequip/nn/_atomwise.py @@ -9,7 +9,10 @@ from nequip.data import AtomicDataDict from nequip.data.transforms import TypeMapper +from nequip.utils import dtype_from_name +from nequip.utils.versions import _TORCH_IS_GE_1_13 from ._graph_mixin import GraphModuleMixin +from ._rescale import RescaleOutput class AtomwiseOperation(GraphModuleMixin, torch.nn.Module): @@ -80,24 +83,42 @@ def __init__( self.out_field = f"{reduce}_{field}" if out_field is None else out_field self._init_irreps( irreps_in=irreps_in, - irreps_out={self.out_field: irreps_in[self.field]} - if self.field in irreps_in - else {}, + irreps_out=( + {self.out_field: irreps_in[self.field]} + if self.field in irreps_in + else {} + ), ) def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: - data = AtomicDataDict.with_batch(data) - data[self.out_field] = scatter( - data[self.field], data[AtomicDataDict.BATCH_KEY], dim=0, reduce=self.reduce - ) + field = data[self.field] + if AtomicDataDict.BATCH_KEY in data: + result = scatter( + field, + data[AtomicDataDict.BATCH_KEY], + dim=0, + dim_size=len(data[AtomicDataDict.BATCH_PTR_KEY]) - 1, + reduce=self.reduce, + ) + else: + # We can significantly simplify and avoid scatters + if self.reduce == "sum": + result = field.sum(dim=0, keepdim=True) + elif self.reduce == "mean": + result = field.mean(dim=0, keepdim=True) + else: + assert False if self.constant != 1.0: - data[self.out_field] = data[self.out_field] * self.constant + result = result * self.constant + data[self.out_field] = result return data class PerSpeciesScaleShift(GraphModuleMixin, torch.nn.Module): """Scale and/or shift a predicted per-atom property based on (learnable) per-species/type parameters. + Note that scaling/shifting is always done (casting into) ``default_dtype``, even if ``model_dtype`` is lower precision. + Args: field: the per-atom field to scale/shift. num_types: the number of types in the model. @@ -119,6 +140,8 @@ class PerSpeciesScaleShift(GraphModuleMixin, torch.nn.Module): shifts_trainable: bool has_scales: bool has_shifts: bool + default_dtype: torch.dtype + _use_fma: bool def __init__( self, @@ -131,6 +154,7 @@ def __init__( out_field: Optional[str] = None, scales_trainable: bool = False, shifts_trainable: bool = False, + default_dtype: Optional[str] = None, irreps_in={}, ): super().__init__() @@ -144,53 +168,101 @@ def __init__( irreps_out={self.out_field: irreps_in[self.field]}, ) + self.default_dtype = dtype_from_name( + torch.get_default_dtype() if default_dtype is None else default_dtype + ) + self.has_shifts = shifts is not None if shifts is not None: - shifts = torch.as_tensor(shifts, dtype=torch.get_default_dtype()) + shifts = torch.as_tensor(shifts, dtype=self.default_dtype) if len(shifts.reshape([-1])) == 1: - shifts = torch.ones(num_types) * shifts + shifts = ( + torch.ones(num_types, dtype=shifts.dtype, device=shifts.device) + * shifts + ) assert shifts.shape == (num_types,), f"Invalid shape of shifts {shifts}" self.shifts_trainable = shifts_trainable if shifts_trainable: self.shifts = torch.nn.Parameter(shifts) else: self.register_buffer("shifts", shifts) + else: + self.register_buffer("shifts", torch.Tensor()) self.has_scales = scales is not None if scales is not None: - scales = torch.as_tensor(scales, dtype=torch.get_default_dtype()) + scales = torch.as_tensor(scales, dtype=self.default_dtype) if len(scales.reshape([-1])) == 1: - scales = torch.ones(num_types) * scales + scales = ( + torch.ones(num_types, dtype=scales.dtype, device=scales.device) + * scales + ) assert scales.shape == (num_types,), f"Invalid shape of scales {scales}" self.scales_trainable = scales_trainable if scales_trainable: self.scales = torch.nn.Parameter(scales) else: self.register_buffer("scales", scales) + else: + self.register_buffer("scales", torch.Tensor()) + assert isinstance(arguments_in_dataset_units, bool) self.arguments_in_dataset_units = arguments_in_dataset_units + # we can use FMA for performance but its type promotion is broken until 1.13 + self._use_fma = _TORCH_IS_GE_1_13 + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: if not (self.has_scales or self.has_shifts): return data - species_idx = data[AtomicDataDict.ATOM_TYPE_KEY] + species_idx = data[AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1) in_field = data[self.field] assert len(in_field) == len( species_idx ), "in_field doesnt seem to have correct per-atom shape" - if self.has_scales: - in_field = self.scales[species_idx].view(-1, 1) * in_field - if self.has_shifts: - in_field = self.shifts[species_idx].view(-1, 1) + in_field + + if self._use_fma and self.has_scales and self.has_shifts: + # we can used an FMA for performance + # addcmul computes + # input + tensor1 * tensor2 elementwise + # it will promote to widest dtype, which comes from shifts/scales + in_field = torch.addcmul( + torch.index_select(self.shifts, 0, species_idx).view(-1, 1), + torch.index_select(self.scales, 0, species_idx).view(-1, 1), + in_field, + ) + else: + # fallback path for torch<1.13 OR mix of enabled shifts and scales + # multiplication / addition promotes dtypes already, so no cast is needed + # this is specifically because self.*[species_idx].view(-1, 1) + # is never a scalar (ndim == 0), since it is always [n_atom, 1] + if self.has_scales: + in_field = ( + torch.index_select(self.scales, 0, species_idx).view(-1, 1) + * in_field + ) + if self.has_shifts: + in_field = ( + torch.index_select(self.shifts, 0, species_idx).view(-1, 1) + + in_field + ) data[self.out_field] = in_field return data - def update_for_rescale(self, rescale_module): - if hasattr(rescale_module, "related_scale_keys"): - if self.out_field not in rescale_module.related_scale_keys: - return + def update_for_rescale(self, rescale_module: RescaleOutput): + if not self.arguments_in_dataset_units: + # nothing to rescale, arguments are in normalized units already / unitless + return + # are we scaling something related to the global rescaling? + if self.field not in rescale_module.scale_keys: + return + # now check that we have the right rescaling in the specific energy case + if self.field == AtomicDataDict.PER_ATOM_ENERGY_KEY and not ( + set(rescale_module.scale_keys) <= set(AtomicDataDict.ALL_ENERGY_KEYS) + ): + raise AssertionError("Some unsupported energy scaling arangement...") if self.arguments_in_dataset_units and rescale_module.has_scale: logging.debug( f"PerSpeciesScaleShift's arguments were in dataset units; rescaling:\n " diff --git a/nequip/nn/_convnetlayer.py b/nequip/nn/_convnetlayer.py index 8d3d0dad..9e5437a8 100644 --- a/nequip/nn/_convnetlayer.py +++ b/nequip/nn/_convnetlayer.py @@ -39,8 +39,8 @@ def __init__( num_layers: int = 3, resnet: bool = False, nonlinearity_type: str = "gate", - nonlinearity_scalars: Dict[int, Callable] = {"e": "ssp", "o": "tanh"}, - nonlinearity_gates: Dict[int, Callable] = {"e": "ssp", "o": "abs"}, + nonlinearity_scalars: Dict[int, Callable] = {"e": "silu", "o": "tanh"}, + nonlinearity_gates: Dict[int, Callable] = {"e": "silu", "o": "tanh"}, ): super().__init__() # initialization diff --git a/nequip/nn/_gmm.py b/nequip/nn/_gmm.py new file mode 100644 index 00000000..51882dcd --- /dev/null +++ b/nequip/nn/_gmm.py @@ -0,0 +1,61 @@ +from typing import Optional + +import torch + +from e3nn import o3 + +from nequip.data import AtomicDataDict + + +from ._graph_mixin import GraphModuleMixin +from nequip.utils.gmm import GaussianMixture + + +class GaussianMixtureModelUncertainty(GraphModuleMixin, torch.nn.Module): + """Compute GMM NLL uncertainties based on some input featurization. + + Args: + gmm_n_components (int or None): if None, use the BIC to determine the number of components. + """ + + feature_field: str + out_field: str + + def __init__( + self, + feature_field: str, + out_field: str, + gmm_n_components: Optional[int] = None, + gmm_covariance_type: str = "full", + irreps_in=None, + ): + super().__init__() + self.feature_field = feature_field + self.out_field = out_field + self._init_irreps( + irreps_in=irreps_in, + required_irreps_in=[feature_field], + irreps_out={out_field: "0e"}, + ) + feature_irreps = self.irreps_in[self.feature_field].simplify() + if not (len(feature_irreps) == 1 and feature_irreps[0].ir == o3.Irrep("0e")): + raise ValueError( + f"GaussianMixtureModelUncertainty feature_field={feature_field} must be only scalars, instead got {feature_irreps}" + ) + # GaussianMixture already correctly registers things as parameters, + # so they will get saved & loaded in state dicts + self.gmm = GaussianMixture( + n_components=gmm_n_components, + n_features=feature_irreps.num_irreps, + covariance_type=gmm_covariance_type, + ) + + @torch.jit.unused + def fit(self, X, seed=None) -> None: + self.gmm.fit(X, rng=seed) + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + if self.gmm.is_fit(): + nll_scores = self.gmm(data[self.feature_field]) + data[self.out_field] = nll_scores + return data diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index 673f8ff0..ee0ce6f9 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -1,5 +1,4 @@ from typing import List, Union, Optional -import warnings import torch @@ -193,10 +192,6 @@ def __init__( ): super().__init__() - warnings.warn( - "!! Stresses in NequIP are in BETA and UNDER DEVELOPMENT: _please_ carefully check the sanity of your results and report any (potential) issues on the GitHub" - ) - if not do_forces: raise NotImplementedError self.do_forces = do_forces @@ -209,17 +204,23 @@ def __init__( irreps_out=self.func.irreps_out.copy(), ) self.irreps_out[AtomicDataDict.FORCE_KEY] = "1o" - self.irreps_out[AtomicDataDict.STRESS_KEY] = "3x1o" - self.irreps_out[AtomicDataDict.VIRIAL_KEY] = "3x1o" + self.irreps_out[AtomicDataDict.STRESS_KEY] = "1o" + self.irreps_out[AtomicDataDict.VIRIAL_KEY] = "1o" # for torchscript compat self.register_buffer("_empty", torch.Tensor()) def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: - data = AtomicDataDict.with_batch(data) + assert AtomicDataDict.EDGE_VECTORS_KEY not in data + + if AtomicDataDict.BATCH_KEY in data: + batch = data[AtomicDataDict.BATCH_KEY] + num_batch: int = len(data[AtomicDataDict.BATCH_PTR_KEY]) - 1 + else: + # Special case for efficiency + batch = self._empty + num_batch: int = 1 - batch = data[AtomicDataDict.BATCH_KEY] - num_batch: int = int(batch.max().cpu().item()) + 1 pos = data[AtomicDataDict.POSITIONS_KEY] has_cell: bool = AtomicDataDict.CELL_KEY in data @@ -243,10 +244,13 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: # Knuth et. al. Comput. Phys. Commun 190, 33-50, 2015 # https://pure.mpg.de/rest/items/item_2085135_9/component/file_2156800/content displacement = torch.zeros( - (num_batch, 3, 3), + (3, 3), dtype=pos.dtype, device=pos.device, ) + if num_batch > 1: + # add n_batch dimension + displacement = displacement.view(-1, 3, 3).expand(num_batch, 3, 3) displacement.requires_grad_(True) data["_displacement"] = displacement # in the above paper, the infinitesimal distortion is *symmetric* @@ -263,10 +267,18 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: symmetric_displacement = 0.5 * (displacement + displacement.transpose(-1, -2)) did_pos_req_grad: bool = pos.requires_grad pos.requires_grad_(True) - # bmm is natom in batch - data[AtomicDataDict.POSITIONS_KEY] = pos + torch.bmm( - pos.unsqueeze(-2), symmetric_displacement[batch] - ).squeeze(-2) + if num_batch > 1: + # bmm is natom in batch + # batched [natom, 1, 3] @ [natom, 3, 3] -> [natom, 1, 3] -> [natom, 3] + data[AtomicDataDict.POSITIONS_KEY] = pos + torch.bmm( + pos.unsqueeze(-2), torch.index_select(symmetric_displacement, 0, batch) + ).squeeze(-2) + else: + # [natom, 3] @ [3, 3] -> [natom, 3] + data[AtomicDataDict.POSITIONS_KEY] = torch.addmm( + pos, pos, symmetric_displacement + ) + # assert torch.equal(pos, data[AtomicDataDict.POSITIONS_KEY]) # we only displace the cell if we have one: if has_cell: # bmm is num_batch in batch @@ -276,9 +288,18 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: # there would then be an infinitesimal rotation of the positions # but not cell, and it thus wouldn't be global and have # no effect due to equivariance/invariance. - data[AtomicDataDict.CELL_KEY] = cell + torch.bmm( - cell, symmetric_displacement - ) + if num_batch > 1: + # [n_batch, 3, 3] @ [n_batch, 3, 3] + data[AtomicDataDict.CELL_KEY] = cell + torch.bmm( + cell, symmetric_displacement + ) + else: + # [3, 3] @ [3, 3] --- enforced to these shapes + tmpcell = cell.squeeze(0) + data[AtomicDataDict.CELL_KEY] = torch.addmm( + tmpcell, tmpcell, symmetric_displacement + ).unsqueeze(0) + # assert torch.equal(cell, data[AtomicDataDict.CELL_KEY]) # Call model and get gradients data = self.func(data) @@ -302,19 +323,21 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: if virial is None: # condition needed to unwrap optional for torchscript assert False, "failed to compute virial autograd" + virial = virial.view(num_batch, 3, 3) # we only compute the stress (1/V * virial) if we have a cell whose volume we can compute if has_cell: # ^ can only scale by cell volume if we have one...: # Rescale stress tensor # See https://github.com/atomistic-machine-learning/schnetpack/blob/master/src/schnetpack/atomistic/output_modules.py#L180 + # See also https://en.wikipedia.org/wiki/Triple_product + # See also https://gitlab.com/ase/ase/-/blob/master/ase/cell.py, + # which uses np.abs(np.linalg.det(cell)) # First dim is batch, second is vec, third is xyz - volume = torch.einsum( - "zi,zi->z", - cell[:, 0, :], - torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), - ).unsqueeze(-1) - stress = virial / volume.view(-1, 1, 1) + # Note the .abs(), since volume should always be positive + # det is equal to a dot (b cross c) + volume = torch.linalg.det(cell).abs().unsqueeze(-1) + stress = virial / volume.view(num_batch, 1, 1) data[AtomicDataDict.CELL_KEY] = orig_cell else: stress = self._empty # torchscript diff --git a/nequip/nn/_graph_mixin.py b/nequip/nn/_graph_mixin.py index 2eb1f64b..eef7d571 100644 --- a/nequip/nn/_graph_mixin.py +++ b/nequip/nn/_graph_mixin.py @@ -217,6 +217,7 @@ def from_parameters( OrderedDict(zip(layers.keys(), built_modules)), ) + @torch.jit.unused def append(self, name: str, module: GraphModuleMixin) -> None: r"""Append a module to the SequentialGraphNetwork. @@ -229,13 +230,14 @@ def append(self, name: str, module: GraphModuleMixin) -> None: self.irreps_out = dict(module.irreps_out) return + @torch.jit.unused def append_from_parameters( self, shared_params: Mapping, name: str, builder: Callable, params: Dict[str, Any] = {}, - ) -> None: + ) -> GraphModuleMixin: r"""Build a module from parameters and append it. Args: @@ -243,6 +245,9 @@ def append_from_parameters( name (str): the name for the module builder (callable): a class or function to build a module params (dict, optional): extra specific parameters for this module that take priority over those in ``shared_params`` + + Returns: + the build module """ instance, _ = instantiate( builder=builder, @@ -252,8 +257,9 @@ def append_from_parameters( all_args=shared_params, ) self.append(name, instance) - return + return instance + @torch.jit.unused def insert( self, name: str, @@ -311,6 +317,7 @@ def insert( return + @torch.jit.unused def insert_from_parameters( self, shared_params: Mapping, @@ -319,7 +326,7 @@ def insert_from_parameters( params: Dict[str, Any] = {}, after: Optional[str] = None, before: Optional[str] = None, - ) -> None: + ) -> GraphModuleMixin: r"""Build a module from parameters and insert it after ``after``. Args: @@ -329,6 +336,9 @@ def insert_from_parameters( params (dict, optional): extra specific parameters for this module that take priority over those in ``shared_params`` after: the name of the module to insert after before: the name of the module to insert before + + Returns: + the inserted module """ if (before is None) is (after is None): raise ValueError("Only one of before or after argument needs to be defined") @@ -347,7 +357,7 @@ def insert_from_parameters( all_args=shared_params, ) self.insert(after=after, before=before, name=name, module=instance) - return + return instance # Copied from https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential # with type annotations added diff --git a/nequip/nn/_graph_model.py b/nequip/nn/_graph_model.py new file mode 100644 index 00000000..d33ad378 --- /dev/null +++ b/nequip/nn/_graph_model.py @@ -0,0 +1,119 @@ +from typing import List, Dict, Any, Optional + +import torch + +from e3nn.util._argtools import _get_device + +from nequip.data import AtomicDataDict + +from ._graph_mixin import GraphModuleMixin +from ._rescale import RescaleOutput + + +class GraphModel(GraphModuleMixin, torch.nn.Module): + """Top-level module for any complete `nequip` model. + + Manages top-level rescaling, dtypes, and more. + + Args: + + """ + + model_dtype: torch.dtype + model_input_fields: List[str] + + _num_rescale_layers: int + + def __init__( + self, + model: GraphModuleMixin, + model_dtype: Optional[torch.dtype] = None, + model_input_fields: Dict[str, Any] = {}, + ) -> None: + super().__init__() + irreps_in = { + # Things that always make sense as inputs: + AtomicDataDict.POSITIONS_KEY: "1o", + AtomicDataDict.EDGE_INDEX_KEY: None, + AtomicDataDict.EDGE_CELL_SHIFT_KEY: None, + AtomicDataDict.CELL_KEY: "1o", # 3 of them, but still + AtomicDataDict.BATCH_KEY: None, + AtomicDataDict.BATCH_PTR_KEY: None, + AtomicDataDict.ATOM_TYPE_KEY: None, + } + model_input_fields = AtomicDataDict._fix_irreps_dict(model_input_fields) + assert len(set(irreps_in.keys()).intersection(model_input_fields.keys())) == 0 + irreps_in.update(model_input_fields) + self._init_irreps(irreps_in=irreps_in, irreps_out=model.irreps_out) + for k, irreps in model.irreps_in.items(): + if self.irreps_in.get(k, None) != irreps: + raise RuntimeError( + f"Model has `{k}` in its irreps_in with irreps `{irreps}`, but `{k}` is missing from/has inconsistent irreps in model_input_fields of `{self.irreps_in.get(k, 'missing')}`" + ) + self.model = model + self.model_dtype = ( + model_dtype if model_dtype is not None else torch.get_default_dtype() + ) + self.model_input_fields = list(self.irreps_in.keys()) + + self._num_rescale_layers = 0 + outer_layer = self.model + while isinstance(outer_layer, RescaleOutput): + self._num_rescale_layers += 1 + outer_layer = outer_layer.model + + # == Rescaling == + @torch.jit.unused + def all_RescaleOutputs(self) -> List[RescaleOutput]: + """All ``RescaleOutput``s wrapping the model, in evaluation order.""" + if self._num_rescale_layers == 0: + return [] + # we know there's at least one + out = [self.model] + for _ in range(self._num_rescale_layers - 1): + out.append(out[-1].model) + # we iterated outermost to innermost, which is opposite of evaluation order + assert len(out) == self._num_rescale_layers + return out[::-1] + + @torch.jit.unused + def unscale( + self, data: AtomicDataDict.Type, force_process: bool = False + ) -> AtomicDataDict.Type: + data_unscaled = data.copy() + # we need to unscale from the outside-in: + for layer in self.all_RescaleOutputs()[::-1]: + data_unscaled = layer.unscale(data_unscaled, force_process=force_process) + return data_unscaled + + @torch.jit.unused + def scale( + self, data: AtomicDataDict.Type, force_process: bool = False + ) -> AtomicDataDict.Type: + data_scaled = data.copy() + # we need to scale from the inside out: + for layer in self.all_RescaleOutputs(): + data_scaled = layer.scale(data_scaled, force_process=force_process) + return data_scaled + + # == Inference == + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + # restrict the input data to allowed keys, and cast to model_dtype + # this also prevents the model from direclty using the dict from the outside, + # preventing weird pass-by-reference bugs + new_data: AtomicDataDict.Type = {} + for k, v in data.items(): + if k in self.model_input_fields: + if v.is_floating_point(): + v = v.to(dtype=self.model_dtype) + new_data[k] = v + # run the model + data = self.model(new_data) + return data + + # == Helpers == + + @torch.jit.unused + def get_device(self) -> torch.device: + return _get_device(self) diff --git a/nequip/nn/_interaction_block.py b/nequip/nn/_interaction_block.py index 99b3acc6..f3164709 100644 --- a/nequip/nn/_interaction_block.py +++ b/nequip/nn/_interaction_block.py @@ -26,7 +26,7 @@ def __init__( invariant_neurons=8, avg_num_neighbors=None, use_sc=True, - nonlinearity_scalars: Dict[int, Callable] = {"e": "ssp"}, + nonlinearity_scalars: Dict[int, Callable] = {"e": "silu"}, ) -> None: """ InteractionBlock. @@ -168,12 +168,13 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: edge_features = self.tp( x[edge_src], data[AtomicDataDict.EDGE_ATTRS_KEY], weight ) - x = scatter(edge_features, edge_dst, dim=0, dim_size=len(x)) - + # divide first for numerics, scatter is linear # Necessary to get TorchScript to be able to type infer when its not None avg_num_neigh: Optional[float] = self.avg_num_neighbors if avg_num_neigh is not None: - x = x.div(avg_num_neigh**0.5) + edge_features = edge_features.div(avg_num_neigh**0.5) + # now scatter down + x = scatter(edge_features, edge_dst, dim=0, dim_size=len(x)) x = self.linear_2(x) diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index 8bea7096..1828ab56 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -1,4 +1,4 @@ -from typing import Sequence, List, Union +from typing import Sequence, List, Union, Optional import torch @@ -6,12 +6,15 @@ from nequip.data import AtomicDataDict from nequip.nn import GraphModuleMixin +from nequip.utils import dtype_from_name @compile_mode("script") class RescaleOutput(GraphModuleMixin, torch.nn.Module): """Wrap a model and rescale its outputs when in ``eval()`` mode. + Note that scaling/shifting is always done (casting into) ``default_dtype``, even if ``model_dtype`` is lower precision. + Args: model : GraphModuleMixin The model whose outputs are to be rescaled. @@ -19,8 +22,6 @@ class RescaleOutput(GraphModuleMixin, torch.nn.Module): Which fields to rescale. shift_keys : list of keys, default [] Which fields to shift after rescaling. - related_scale_keys: list of keys that could be contingent to this rescale - related_shift_keys: list of keys that could be contingent to this rescale scale_by : floating or Tensor, default 1. The scaling factor by which to multiply fields in ``scale``. shift_by : floating or Tensor, default 0. @@ -31,25 +32,25 @@ class RescaleOutput(GraphModuleMixin, torch.nn.Module): scale_keys: List[str] shift_keys: List[str] - related_scale_keys: List[str] - related_shift_keys: List[str] scale_trainble: bool rescale_trainable: bool + _all_keys: List[str] has_scale: bool has_shift: bool + default_dtype: torch.dtype + def __init__( self, model: GraphModuleMixin, scale_keys: Union[Sequence[str], str] = [], shift_keys: Union[Sequence[str], str] = [], - related_shift_keys: Union[Sequence[str], str] = [], - related_scale_keys: Union[Sequence[str], str] = [], scale_by=None, shift_by=None, shift_trainable: bool = False, scale_trainable: bool = False, + default_dtype: Optional[str] = None, irreps_in: dict = {}, ): super().__init__() @@ -81,13 +82,16 @@ def __init__( self.scale_keys = list(scale_keys) self.shift_keys = list(shift_keys) - self.related_scale_keys = list(set(related_scale_keys).union(scale_keys)) - self.related_shift_keys = list(set(related_shift_keys).union(shift_keys)) + self._all_keys = list(all_keys) + + self.default_dtype = dtype_from_name( + torch.get_default_dtype() if default_dtype is None else default_dtype + ) self.has_scale = scale_by is not None self.scale_trainble = scale_trainable if self.has_scale: - scale_by = torch.as_tensor(scale_by) + scale_by = torch.as_tensor(scale_by, dtype=self.default_dtype) if self.scale_trainble: self.scale_by = torch.nn.Parameter(scale_by) else: @@ -103,7 +107,7 @@ def __init__( self.has_shift = shift_by is not None self.rescale_trainable = shift_trainable if self.has_shift: - shift_by = torch.as_tensor(shift_by) + shift_by = torch.as_tensor(shift_by, dtype=self.default_dtype) if self.rescale_trainable: self.shift_by = torch.nn.Parameter(shift_by) else: @@ -139,16 +143,29 @@ def get_inner_model(self): def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: data = self.model(data) if self.training: - return data + # no scaling, but still need to promote for consistent dtype behavior + # this is hopefully a no-op in most circumstances due to a + # preceeding PerSpecies rescale promoting to default_dtype anyway: + for field in self._all_keys: + data[field] = data[field].to(dtype=self.default_dtype) else: # Scale then shift + # * and + promote dtypes by default, but not when the other + # operand is a scalar, which `scale/shift_by` are. + # We solve this by expanding `scale/shift_by` to tensors + # This is free and doesn't allocate new memory on CUDA: + # https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand + # confirmed in PyTorch slack + # https://pytorch.slack.com/archives/C3PDTEV8E/p1671652283801129 if self.has_scale: for field in self.scale_keys: - data[field] = data[field] * self.scale_by + v = data[field] + data[field] = v * self.scale_by.expand(v.shape) if self.has_shift: for field in self.shift_keys: - data[field] = data[field] + self.shift_by - return data + v = data[field] + data[field] = v + self.shift_by.expand(v.shape) + return data @torch.jit.export def scale( diff --git a/nequip/nn/embedding/__init__.py b/nequip/nn/embedding/__init__.py index dfc9b710..9a0c0d86 100644 --- a/nequip/nn/embedding/__init__.py +++ b/nequip/nn/embedding/__init__.py @@ -1,4 +1,13 @@ from ._one_hot import OneHotAtomEncoding -from ._edge import SphericalHarmonicEdgeAttrs, RadialBasisEdgeEncoding +from ._edge import ( + SphericalHarmonicEdgeAttrs, + RadialBasisEdgeEncoding, + AddRadialCutoffToData, +) -__all__ = [OneHotAtomEncoding, SphericalHarmonicEdgeAttrs, RadialBasisEdgeEncoding] +__all__ = [ + OneHotAtomEncoding, + SphericalHarmonicEdgeAttrs, + RadialBasisEdgeEncoding, + AddRadialCutoffToData, +] diff --git a/nequip/nn/embedding/_edge.py b/nequip/nn/embedding/_edge.py index 3705ae35..4585fec7 100644 --- a/nequip/nn/embedding/_edge.py +++ b/nequip/nn/embedding/_edge.py @@ -76,14 +76,39 @@ def __init__( self.out_field = out_field self._init_irreps( irreps_in=irreps_in, - irreps_out={self.out_field: o3.Irreps([(self.basis.num_basis, (0, 1))])}, + irreps_out={ + self.out_field: o3.Irreps([(self.basis.num_basis, (0, 1))]), + AtomicDataDict.EDGE_CUTOFF_KEY: "0e", + }, ) def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: data = AtomicDataDict.with_edge_vectors(data, with_lengths=True) edge_length = data[AtomicDataDict.EDGE_LENGTH_KEY] - edge_length_embedded = ( - self.basis(edge_length) * self.cutoff(edge_length)[:, None] - ) + cutoff = self.cutoff(edge_length).unsqueeze(-1) + edge_length_embedded = self.basis(edge_length) * cutoff data[self.out_field] = edge_length_embedded + data[AtomicDataDict.EDGE_CUTOFF_KEY] = cutoff + return data + + +@compile_mode("script") +class AddRadialCutoffToData(GraphModuleMixin, torch.nn.Module): + def __init__( + self, + cutoff=PolynomialCutoff, + cutoff_kwargs={}, + irreps_in=None, + ): + super().__init__() + self.cutoff = cutoff(**cutoff_kwargs) + self._init_irreps( + irreps_in=irreps_in, irreps_out={AtomicDataDict.EDGE_CUTOFF_KEY: "0e"} + ) + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + data = AtomicDataDict.with_edge_vectors(data, with_lengths=True) + edge_length = data[AtomicDataDict.EDGE_LENGTH_KEY] + cutoff = self.cutoff(edge_length).unsqueeze(-1) + data[AtomicDataDict.EDGE_CUTOFF_KEY] = cutoff return data diff --git a/nequip/nn/pair_potential.py b/nequip/nn/pair_potential.py new file mode 100644 index 00000000..f448afc3 --- /dev/null +++ b/nequip/nn/pair_potential.py @@ -0,0 +1,350 @@ +from typing import Union, Optional, Dict, List + +import torch +from torch_runstats.scatter import scatter + +from e3nn.util.jit import compile_mode + +import ase.data + +from nequip.data import AtomicDataDict +from nequip.nn import GraphModuleMixin, RescaleOutput + + +@torch.jit.script +def _param(param, index1, index2): + if param.ndim == 2: + # make it symmetric + param = param.triu() + param.triu(1).transpose(-1, -2) + # get for each atom pair + param = torch.index_select(param.view(-1), 0, index1 * param.shape[0] + index2) + # make it positive + param = param.relu() # TODO: better way? + return param + + +@compile_mode("script") +class LennardJones(GraphModuleMixin, torch.nn.Module): + """Lennard-Jones and related pair potentials.""" + + lj_style: str + exponent: float + + def __init__( + self, + num_types: int, + lj_sigma: Union[torch.Tensor, float], + lj_delta: Union[torch.Tensor, float] = 0, + lj_epsilon: Optional[Union[torch.Tensor, float]] = None, + lj_sigma_trainable: bool = False, + lj_delta_trainable: bool = False, + lj_epsilon_trainable: bool = False, + lj_exponent: Optional[float] = None, + lj_per_type: bool = True, + lj_style: str = "lj", + irreps_in=None, + ) -> None: + super().__init__() + self._init_irreps( + irreps_in=irreps_in, irreps_out={AtomicDataDict.PER_ATOM_ENERGY_KEY: "0e"} + ) + assert lj_style in ("lj", "lj_repulsive_only", "repulsive") + self.lj_style = lj_style + + for param, (value, trainable) in { + "epsilon": (lj_epsilon, lj_epsilon_trainable), + "sigma": (lj_sigma, lj_sigma_trainable), + "delta": (lj_delta, lj_delta_trainable), + }.items(): + if value is None: + self.register_buffer(param, torch.Tensor()) # torchscript + continue + value = torch.as_tensor(value, dtype=torch.get_default_dtype()) + if value.ndim == 0 and lj_per_type: + # one scalar for all pair types + value = ( + torch.ones( + num_types, num_types, device=value.device, dtype=value.dtype + ) + * value + ) + elif value.ndim == 2: + assert lj_per_type + # one per pair type, check symmetric + assert value.shape == (num_types, num_types) + # per-species square, make sure symmetric + assert torch.equal(value, value.T) + value = torch.triu(value) + else: + raise ValueError + setattr(self, param, torch.nn.Parameter(value, requires_grad=trainable)) + + if lj_exponent is None: + lj_exponent = 6.0 + self.exponent = lj_exponent + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + data = AtomicDataDict.with_edge_vectors(data, with_lengths=True) + edge_center = data[AtomicDataDict.EDGE_INDEX_KEY][0] + atom_types = data[AtomicDataDict.ATOM_TYPE_KEY] + edge_len = data[AtomicDataDict.EDGE_LENGTH_KEY].unsqueeze(-1) + edge_types = torch.index_select( + atom_types, 0, data[AtomicDataDict.EDGE_INDEX_KEY].reshape(-1) + ).view(2, -1) + index1 = edge_types[0] + index2 = edge_types[1] + + sigma = _param(self.sigma, index1, index2) + delta = _param(self.delta, index1, index2) + epsilon = _param(self.epsilon, index1, index2) + + if self.lj_style == "repulsive": + # 0.5 to assign half and half the energy to each side of the interaction + lj_eng = 0.5 * epsilon * ((sigma * (edge_len - delta)) ** -self.exponent) + else: + lj_eng = (sigma / (edge_len - delta)) ** self.exponent + lj_eng = torch.neg(lj_eng) + lj_eng = lj_eng + lj_eng.square() + # 2.0 because we do the slightly symmetric thing and let + # ij and ji each contribute half of the LJ energy of the pair + # this avoids indexing out certain edges in the general case where + # the edges are not ordered. + lj_eng = (2.0 * epsilon) * lj_eng + + if self.lj_style == "lj_repulsive_only": + # if taking only the repulsive part, shift up so the minima is at eng=0 + lj_eng = lj_eng + epsilon + # this is continuous at the minima, and we mask out everything greater + # TODO: this is probably broken with NaNs at delta + lj_eng = lj_eng * (edge_len < (2 ** (1.0 / self.exponent) + delta)) + + # apply the cutoff for smoothness + lj_eng = lj_eng * data[AtomicDataDict.EDGE_CUTOFF_KEY] + + # sum edge LJ energies onto atoms + atomic_eng = scatter( + lj_eng, + edge_center, + dim=0, + dim_size=len(data[AtomicDataDict.POSITIONS_KEY]), + ) + if AtomicDataDict.PER_ATOM_ENERGY_KEY in data: + atomic_eng = atomic_eng + data[AtomicDataDict.PER_ATOM_ENERGY_KEY] + data[AtomicDataDict.PER_ATOM_ENERGY_KEY] = atomic_eng + return data + + def __repr__(self) -> str: + def _f(e): + e = e.data + if e.ndim == 0: + return f"{e:.6f}" + elif e.ndim == 2: + return f"{e}" + + return f"PairPotential(lj_style={self.lj_style} | σ={_f(self.sigma)} δ={_f(self.delta)} ε={_f(self.epsilon)} exp={self.exponent:.1f})" + + def update_for_rescale(self, rescale_module: RescaleOutput): + if AtomicDataDict.PER_ATOM_ENERGY_KEY not in rescale_module.scale_keys: + return + if not rescale_module.has_scale: + return + with torch.no_grad(): + # Our energy will be scaled by scale_by later, so we have to divide here to cancel out: + self.epsilon.copy_(self.epsilon / rescale_module.scale_by.item()) + + +@compile_mode("script") +class SimpleLennardJones(GraphModuleMixin, torch.nn.Module): + """Simple Lennard-Jones.""" + + lj_sigma: float + lj_epsilon: float + lj_use_cutoff: bool + + def __init__( + self, + lj_sigma: float, + lj_epsilon: float, + lj_use_cutoff: bool = False, + irreps_in=None, + ) -> None: + super().__init__() + self._init_irreps( + irreps_in=irreps_in, irreps_out={AtomicDataDict.PER_ATOM_ENERGY_KEY: "0e"} + ) + self.lj_sigma, self.lj_epsilon, self.lj_use_cutoff = ( + lj_sigma, + lj_epsilon, + lj_use_cutoff, + ) + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + data = AtomicDataDict.with_edge_vectors(data, with_lengths=True) + edge_center = data[AtomicDataDict.EDGE_INDEX_KEY][0] + edge_len = data[AtomicDataDict.EDGE_LENGTH_KEY].unsqueeze(-1) + + lj_eng = (self.lj_sigma / edge_len) ** 6.0 + lj_eng = lj_eng.square() - lj_eng + lj_eng = 2 * self.lj_epsilon * lj_eng + + if self.lj_use_cutoff: + # apply the cutoff for smoothness + lj_eng = lj_eng * data[AtomicDataDict.EDGE_CUTOFF_KEY] + + # sum edge LJ energies onto atoms + atomic_eng = scatter( + lj_eng, + edge_center, + dim=0, + dim_size=len(data[AtomicDataDict.POSITIONS_KEY]), + ) + if AtomicDataDict.PER_ATOM_ENERGY_KEY in data: + atomic_eng = atomic_eng + data[AtomicDataDict.PER_ATOM_ENERGY_KEY] + data[AtomicDataDict.PER_ATOM_ENERGY_KEY] = atomic_eng + return data + + def update_for_rescale(self, rescale_module: RescaleOutput): + if AtomicDataDict.PER_ATOM_ENERGY_KEY not in rescale_module.scale_keys: + return + if not rescale_module.has_scale: + return + # Our energy will be scaled by scale_by later, so we have to divide here to cancel out: + self.lj_epsilon /= rescale_module.scale_by.item() + + +@torch.jit.script +def _zbl( + Z: torch.Tensor, + r: torch.Tensor, + atom_types: torch.Tensor, + edge_index: torch.Tensor, + qqr2exesquare: float, +) -> torch.Tensor: + # from LAMMPS pair_zbl_const.h + pzbl: float = 0.23 + a0: float = 0.46850 + c1: float = 0.02817 + c2: float = 0.28022 + c3: float = 0.50986 + c4: float = 0.18175 + d1: float = -0.20162 + d2: float = -0.40290 + d3: float = -0.94229 + d4: float = -3.19980 + # compute + edge_types = torch.index_select(atom_types, 0, edge_index.reshape(-1)) + Z = torch.index_select(Z, 0, edge_types.view(-1)).view( + 2, -1 + ) # [center/neigh, n_edge] + Zi, Zj = Z[0], Z[1] + del edge_types, Z + x = ((torch.pow(Zi, pzbl) + torch.pow(Zj, pzbl)) * r) / a0 + psi = ( + c1 * (d1 * x).exp() + + c2 * (d2 * x).exp() + + c3 * (d3 * x).exp() + + c4 * (d4 * x).exp() + ) + eng = qqr2exesquare * ((Zi * Zj) / r) * psi + return eng + + +@compile_mode("script") +class ZBL(GraphModuleMixin, torch.nn.Module): + """Add a ZBL pair potential to the edge energy. + + Args: + units (str): what units the model/data are in using LAMMPS names. + """ + + num_types: int + + def __init__( + self, + num_types: int, + units: str, + type_to_chemical_symbol: Optional[Dict[int, str]] = None, + irreps_in=None, + ): + super().__init__() + self._init_irreps( + irreps_in=irreps_in, irreps_out={AtomicDataDict.PER_ATOM_ENERGY_KEY: "0e"} + ) + if type_to_chemical_symbol is not None: + assert set(type_to_chemical_symbol.keys()) == set(range(num_types)) + atomic_numbers: List[int] = [ + ase.data.atomic_numbers[type_to_chemical_symbol[type_i]] + for type_i in range(num_types) + ] + if min(atomic_numbers) < 1: + raise ValueError( + f"Your chemical symbols don't seem valid (minimum atomic number is {min(atomic_numbers)} < 1); did you try to use fake chemical symbols for arbitrary atom types? If so, instead provide atom_types directly in your dataset and specify `type_names` and `type_to_chemical_symbol` in your config. `type_to_chemical_symbol` then tells ZBL what atomic numbers to use for the various atom types in your system." + ) + else: + raise RuntimeError( + "Either chemical_symbol_to_type or type_to_chemical_symbol is required." + ) + assert len(atomic_numbers) == num_types + # LAMMPS note on units: + # > The numerical values of the exponential decay constants in the + # > screening function depend on the unit of distance. In the above + # > equation they are given for units of Angstroms. LAMMPS will + # > automatically convert these values to the distance unit of the + # > specified LAMMPS units setting. The values of Z should always be + # > given as multiples of a proton’s charge, e.g. 29.0 for copper. + # So, we store the atomic numbers directly. + self.register_buffer( + "atomic_numbers", + torch.as_tensor(atomic_numbers, dtype=torch.get_default_dtype()), + ) + # And we have to convert our value of prefector into the model's physical units + # Here, prefactor is (electron charge)^2 / (4 * pi * electrical permisivity of vacuum) + # we have a value for that in eV and Angstrom + # See https://github.com/lammps/lammps/blob/c415385ab4b0983fa1c72f9e92a09a8ed7eebe4a/src/update.cpp#L187 for values from LAMMPS + # LAMMPS uses `force->qqr2e * force->qelectron * force->qelectron` + # Make it a buffer so rescalings are persistent, it still acts as a scalar Tensor + self.register_buffer( + "_qqr2exesquare", + torch.as_tensor( + {"metal": 14.399645 * (1.0) ** 2, "real": 332.06371 * (1.0) ** 2}[ + units + ], + dtype=torch.float64, + ) + * 0.5, # Put half the energy on each of ij, ji + ) + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + data = AtomicDataDict.with_edge_vectors(data, with_lengths=True) + edge_center = data[AtomicDataDict.EDGE_INDEX_KEY][0] + + zbl_edge_eng = _zbl( + Z=self.atomic_numbers, + r=data[AtomicDataDict.EDGE_LENGTH_KEY], + atom_types=data[AtomicDataDict.ATOM_TYPE_KEY], + edge_index=data[AtomicDataDict.EDGE_INDEX_KEY], + qqr2exesquare=self._qqr2exesquare, + ).unsqueeze(-1) + # apply cutoff + zbl_edge_eng = zbl_edge_eng * data[AtomicDataDict.EDGE_CUTOFF_KEY] + atomic_eng = scatter( + zbl_edge_eng, + edge_center, + dim=0, + dim_size=len(data[AtomicDataDict.POSITIONS_KEY]), + ) + if AtomicDataDict.PER_ATOM_ENERGY_KEY in data: + atomic_eng = atomic_eng + data[AtomicDataDict.PER_ATOM_ENERGY_KEY] + data[AtomicDataDict.PER_ATOM_ENERGY_KEY] = atomic_eng + return data + + def update_for_rescale(self, rescale_module: RescaleOutput): + if AtomicDataDict.PER_ATOM_ENERGY_KEY not in rescale_module.scale_keys: + return + if not rescale_module.has_scale: + return + # Our energy will be scaled by scale_by later, so we have to divide here to cancel out: + self._qqr2exesquare /= rescale_module.scale_by.item() + + +__all__ = [LennardJones, ZBL] diff --git a/nequip/scripts/benchmark.py b/nequip/scripts/benchmark.py index 1deb0de2..2ac1e607 100644 --- a/nequip/scripts/benchmark.py +++ b/nequip/scripts/benchmark.py @@ -7,6 +7,7 @@ import sys import pdb import traceback +import pickle import torch from torch.utils.benchmark import Timer, Measurement @@ -57,19 +58,13 @@ def main(args=None): "-n", help="Number of trials.", type=int, - default=30, + default=None, ) parser.add_argument( "--n-data", help="Number of frames to use.", type=int, - default=1, - ) - parser.add_argument( - "--timestep", - help="MD timestep for ns/day esimation, in fs. Defauts to 1fs.", - type=float, - default=1, + default=2, ) parser.add_argument( "--no-compile", @@ -116,20 +111,22 @@ def main(args=None): dataset = dataset_from_config(config) dataset_time = time.time() - dataset_time print(f" loading dataset took {dataset_time:.4f}s") + print( + f" loaded dataset of size {len(dataset)} and sampled --n-data={args.n_data} frames" + ) dataset_rng = torch.Generator() dataset_rng.manual_seed(config.get("dataset_seed", config.get("seed", 12345))) + dataset = dataset.index_select( + torch.randperm(len(dataset), generator=dataset_rng)[: args.n_data] + ) datas_list = [ - AtomicData.to_AtomicDataDict(dataset[i].to(device)) - for i in torch.randperm(len(dataset), generator=dataset_rng)[: args.n_data] + AtomicData.to_AtomicDataDict(dataset[i].to(device)) for i in range(args.n_data) ] n_atom: int = len(datas_list[0]["pos"]) if not all(len(d["pos"]) == n_atom for d in datas_list): raise NotImplementedError( "nequip-benchmark does not currently handle benchmarking on data frames with variable number of atoms" ) - print( - f" loaded dataset of size {len(dataset)} and sampled --n-data={args.n_data} frames" - ) # print some dataset information print(" benchmark frames statistics:") print(f" number of atoms: {n_atom}") @@ -157,6 +154,8 @@ def main(args=None): if args.n == 0: print("Got -n 0, so quitting without running benchmark.") return + elif args.n is None: + args.n = 5 if args.profile else 30 # Load model: if args.model is None: @@ -239,8 +238,11 @@ def trace_handler(p): on_trace_ready=trace_handler, ) as p: for _ in range(1 + warmup + args.n): - model(next(datas).copy()) + out = model(next(datas).copy()) + out[AtomicDataDict.TOTAL_ENERGY_KEY].item() p.step() + + print(p.key_averages().table(sort_by="cuda_time_total", row_limit=100)) elif args.pdb: print("Running model under debugger...") try: @@ -270,6 +272,14 @@ def trace_handler(p): ) del errstr else: + if args.memory_summary and torch.cuda.is_available(): + torch.cuda.memory._record_memory_history( + True, + # keep 100,000 alloc/free events from before the snapshot + trace_alloc_max_entries=100000, + # record stack information for the trace events + trace_alloc_record_context=True, + ) print("Warmup...") warmup_time = time.time() for _ in range(warmup): @@ -278,22 +288,34 @@ def trace_handler(p): print(f" {warmup} calls of warmup took {warmup_time:.4f}s") print("Benchmarking...") + # just time t = Timer( - stmt="model(next(datas).copy())", globals={"model": model, "datas": datas} + stmt="model(next(datas).copy())['total_energy'].item()", + globals={"model": model, "datas": datas}, ) perloop: Measurement = t.timeit(args.n) if args.memory_summary and torch.cuda.is_available(): print("Memory usage summary:") print(torch.cuda.memory_summary()) + snapshot = torch.cuda.memory._snapshot() + + with open("snapshot.pickle", "wb") as f: + pickle.dump(snapshot, f) print(" -- Results --") print( f"PLEASE NOTE: these are speeds for the MODEL, evaluated on --n-data={args.n_data} configurations kept in memory." ) print( - " \\_ MD itself, memory copies, and other overhead will affect real-world performance." + "A variety of factors affect the performance in real molecular dynamics calculations:" + ) + print( + "!!! Molecular dynamics speeds should be measured in LAMMPS; speeds from nequip-benchmark should only be used as an estimate of RELATIVE speed among different hyperparameters." + ) + print( + "Please further note that relative speed ordering of hyperparameters is NOT NECESSARILY CONSISTENT across different classes of GPUs (i.e. A100 vs V100 vs consumer) or GPUs vs CPUs." ) print() trim_time = trim_sigfig(perloop.times[0], perloop.significant_figures) @@ -302,19 +324,6 @@ def trace_handler(p): trim_time / time_scale ) print(f"The average call took {time_str}{time_unit}") - print( - "Assuming linear scaling — which is ALMOST NEVER true in practice, especially on GPU —" - ) - per_atom_time = trim_time / n_atom - time_unit_per, time_scale_per = select_unit(per_atom_time) - print( - f" \\_ this comes out to {per_atom_time/time_scale_per:g} {time_unit_per}/atom/call" - ) - ns_day = (86400.0 / trim_time) * args.timestep * 1e-6 - # day in s^ s/step^ ^ fs / step ^ ns / fs - print( - f"For this system, at a {args.timestep:.2f}fs timestep, this comes out to {ns_day:.2f} ns/day" - ) if __name__ == "__main__": diff --git a/nequip/scripts/deploy.py b/nequip/scripts/deploy.py index 394c0005..a0772df9 100644 --- a/nequip/scripts/deploy.py +++ b/nequip/scripts/deploy.py @@ -1,15 +1,17 @@ import sys if sys.version_info[1] >= 8: - from typing import Final + from typing import Final, Optional else: - from typing_extensions import Final + from typing_extensions import Final, Optional from typing import Tuple, Dict, Union import argparse import pathlib import logging import yaml import itertools +import packaging.version +import warnings # This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process that has itself initialized PyTorch. # Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance. @@ -17,15 +19,14 @@ import torch -import ase.data - from e3nn.util.jit import script from nequip.model import model_from_config -from nequip.train import Trainer +from nequip.data import dataset_from_config from nequip.utils import Config -from nequip.utils.versions import check_code_version, get_config_code_versions +from nequip.utils.versions import check_code_version, get_current_code_versions from nequip.scripts.train import default_config +from nequip.utils.misc import dtype_to_name from nequip.utils._global_options import _set_global_options CONFIG_KEY: Final[str] = "config" @@ -39,6 +40,8 @@ JIT_BAILOUT_KEY: Final[str] = "_jit_bailout_depth" JIT_FUSION_STRATEGY: Final[str] = "_jit_fusion_strategy" TF32_KEY: Final[str] = "allow_tf32" +DEFAULT_DTYPE_KEY: Final[str] = "default_dtype" +MODEL_DTYPE_KEY: Final[str] = "model_dtype" _ALL_METADATA_KEYS = [ CONFIG_KEY, @@ -51,9 +54,29 @@ JIT_BAILOUT_KEY, JIT_FUSION_STRATEGY, TF32_KEY, + DEFAULT_DTYPE_KEY, + MODEL_DTYPE_KEY, ] +def _register_metadata_key(key: str) -> None: + _ALL_METADATA_KEYS.append(key) + + +_current_metadata: Optional[dict] = None + + +def _set_deploy_metadata(key: str, value) -> None: + # TODO: not thread safe but who cares? + global _current_metadata + if _current_metadata is None: + pass # not deploying right now + elif key in _current_metadata: + raise RuntimeError(f"{key} already set in the deployment metadata") + else: + _current_metadata[key] = value + + def _compile_for_deploy(model): model.eval() @@ -100,11 +123,28 @@ def load_deployed_model( model = torch.jit.freeze(model) # Everything we store right now is ASCII, so decode for printing metadata = {k: v.decode("ascii") for k, v in metadata.items()} + # Update metadata for backward compatibility + if metadata[DEFAULT_DTYPE_KEY] == "": + # Default and model go together + assert metadata[MODEL_DTYPE_KEY] == "" + # If there isn't a dtype, it should be older than 0.6.0: + assert packaging.version.parse( + metadata[NEQUIP_VERSION_KEY] + ) < packaging.version.parse("0.6.0") + # i.e. no value due to L85 above + # The old pre-0.6.0 defaults: + metadata[DEFAULT_DTYPE_KEY] = "float32" + metadata[MODEL_DTYPE_KEY] = "float32" + warnings.warn( + "Models deployed before v0.6.0 don't contain information about their default_dtype or model_dtype; assuming the old default of float32 for both, but this might not be right if you had explicitly set default_dtype=float64." + ) + # Set up global settings: assert set_global_options in (True, False, "warn") if set_global_options: global_config_dict = {} global_config_dict["allow_tf32"] = bool(int(metadata[TF32_KEY])) + global_config_dict["default_dtype"] = str(metadata[DEFAULT_DTYPE_KEY]) # JIT strategy strategy = metadata.get(JIT_FUSION_STRATEGY, "") if strategy != "": @@ -164,6 +204,25 @@ def main(args=None): help="Path to a working directory from a training session to deploy.", type=pathlib.Path, ) + build_parser.add_argument( + "--checkpoint", + help="Which model checkpoint from --train-dir to deploy. Defaults to `best_model.pth`. If --train-dir is provided, this is a relative path; if --model is provided instead, this is an absolute path.", + type=str, + default=None, + ) + build_parser.add_argument( + "--override", + help="Override top-level configuration keys from the `--train-dir`/`--model`'s config YAML file. This should be a valid YAML string. Unless you know why you need to, do not use this option.", + type=str, + default=None, + ) + build_parser.add_argument( + "--using-dataset", + help="Allow model builders to use a dataset during deployment. By default uses the training dataset, but can point to a YAML file for another dataset.", + type=pathlib.Path, + const=True, + nargs="?", + ) build_parser.add_argument( "out_file", help="Output file for deployed model.", @@ -196,10 +255,15 @@ def main(args=None): logging.debug(f"Model had config:\n{config}") elif args.command == "build": + state_dict = None if args.model and args.train_dir: raise ValueError("--model and --train-dir cannot both be specified.") + checkpoint_file = args.checkpoint if args.train_dir is not None: - logging.info("Loading best_model from training session...") + if checkpoint_file is None: + checkpoint_file = "best_model.pth" + logging.info(f"Loading {checkpoint_file} from training session...") + checkpoint_file = str(args.train_dir / "best_model.pth") config = Config.from_file(str(args.train_dir / "config.yaml")) elif args.model is not None: logging.info("Building model from config...") @@ -207,18 +271,45 @@ def main(args=None): else: raise ValueError("one of --train-dir or --model must be given") + # Set override options before _set_global_options so that things like allow_tf32 are correctly handled + if args.override is not None: + override_options = yaml.load(args.override, Loader=yaml.Loader) + assert isinstance( + override_options, dict + ), "--override's YAML string must define a dictionary of top-level options" + overridden_keys = set(config.keys()).intersection(override_options.keys()) + set_keys = set(override_options.keys()) - set(overridden_keys) + logging.info( + f"--override: overrode keys {list(overridden_keys)} and set new keys {list(set_keys)}" + ) + config.update(override_options) + del override_options, overridden_keys, set_keys + _set_global_options(config) check_code_version(config) # -- load model -- - if args.train_dir is not None: - model, _ = Trainer.load_model_from_training_session( - args.train_dir, model_name="best_model.pth", device="cpu" + # figure out first if a dataset is involved + dataset = None + if args.using_dataset: + dataset_config = config + if args.using_dataset is not True: + dataset_config = Config.from_file(str(args.using_dataset)) + dataset = dataset_from_config(dataset_config) + if args.using_dataset is True: + # we're using the one from training config + # downselect to training set + dataset = dataset.index_select(config.train_idcs) + # build the actual model] + # reset the global metadata dict so that model builders can fill it: + global _current_metadata + _current_metadata = {} + model = model_from_config(config, dataset=dataset, deploy=True) + if checkpoint_file is not None: + state_dict = torch.load( + str(args.train_dir / "best_model.pth"), map_location="cpu" ) - elif args.model is not None: - model = model_from_config(config, deploy=True) - else: - raise AssertionError + model.load_state_dict(state_dict, strict=True) # -- compile -- model = _compile_for_deploy(model) @@ -226,7 +317,7 @@ def main(args=None): # Deploy metadata: dict = {} - code_versions, code_commits = get_config_code_versions(config) + code_versions, code_commits = get_current_code_versions(config) for code, version in code_versions.items(): metadata[code + "_version"] = version if len(code_commits) > 0: @@ -235,29 +326,33 @@ def main(args=None): ) metadata[R_MAX_KEY] = str(float(config["r_max"])) - if "allowed_species" in config: - # This is from before the atomic number updates - n_species = len(config["allowed_species"]) - type_names = { - type: ase.data.chemical_symbols[atomic_num] - for type, atomic_num in enumerate(config["allowed_species"]) - } - else: - # The new atomic number setup - n_species = str(config["num_types"]) - type_names = config["type_names"] + n_species = str(config["num_types"]) + type_names = config["type_names"] metadata[N_SPECIES_KEY] = str(n_species) metadata[TYPE_NAMES_KEY] = " ".join(type_names) metadata[JIT_BAILOUT_KEY] = str(config[JIT_BAILOUT_KEY]) - if int(torch.__version__.split(".")[1]) >= 11 and JIT_FUSION_STRATEGY in config: + if ( + packaging.version.parse(torch.__version__) + >= packaging.version.parse("1.11") + and JIT_FUSION_STRATEGY in config + ): metadata[JIT_FUSION_STRATEGY] = ";".join( "%s,%i" % e for e in config[JIT_FUSION_STRATEGY] ) metadata[TF32_KEY] = str(int(config["allow_tf32"])) - metadata[CONFIG_KEY] = yaml.dump(dict(config)) + metadata[DEFAULT_DTYPE_KEY] = dtype_to_name(config["default_dtype"]) + metadata[MODEL_DTYPE_KEY] = dtype_to_name(config["model_dtype"]) + metadata[CONFIG_KEY] = yaml.dump(Config.as_dict(config)) + + for k, v in _current_metadata.items(): + if k in metadata: + raise RuntimeError(f"Custom deploy key {k} was already set") + metadata[k] = v + _current_metadata = None metadata = {k: v.encode("ascii") for k, v in metadata.items()} + torch.jit.save(model, args.out_file, _extra_files=metadata) else: raise ValueError diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py index f7dfa12b..20382eef 100644 --- a/nequip/scripts/evaluate.py +++ b/nequip/scripts/evaluate.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple, List import sys import argparse import logging @@ -11,19 +11,68 @@ import torch -from nequip.data import AtomicData, Collater, dataset_from_config, register_fields -from nequip.scripts.deploy import load_deployed_model, R_MAX_KEY +from nequip.data import ( + AtomicData, + Collater, + dataset_from_config, + register_fields, + _register_field_prefix, +) +from nequip.scripts.deploy import load_deployed_model, R_MAX_KEY, TYPE_NAMES_KEY from nequip.scripts._logger import set_up_script_logger from nequip.scripts.train import default_config, check_code_version from nequip.utils._global_options import _set_global_options from nequip.train import Trainer, Loss, Metrics from nequip.utils import load_file, instantiate, Config - -ORIGINAL_DATASET_INDEX_KEY: str = "original_dataset_index" +ORIGINAL_DATASET_PREFIX: str = "original_dataset_" +ORIGINAL_DATASET_INDEX_KEY: str = ORIGINAL_DATASET_PREFIX + "index" register_fields(graph_fields=[ORIGINAL_DATASET_INDEX_KEY]) +def _load_deployed_or_traindir( + path: Path, device, freeze: bool = True +) -> Tuple[torch.nn.Module, bool, float, List[str]]: + loaded_deployed_model: bool = False + model_r_max = None + type_names = None + try: + model, metadata = load_deployed_model( + path, + device=device, + set_global_options=True, # don't warn that setting + freeze=freeze, + ) + # the global settings for a deployed model are set by + # set_global_options in the call to load_deployed_model + # above + model_r_max = float(metadata[R_MAX_KEY]) + type_names = metadata[TYPE_NAMES_KEY].split(" ") + loaded_deployed_model = True + except ValueError: # its not a deployed model + loaded_deployed_model = False + # we don't do this in the `except:` block to avoid "during handing of this exception another exception" + # chains if there is an issue loading the training session model. This makes the error messages more + # comprehensible: + if not loaded_deployed_model: + # Use the model config, regardless of dataset config + global_config = path.parent / "config.yaml" + global_config = Config.from_file(str(global_config), defaults=default_config) + _set_global_options(global_config) + check_code_version(global_config) + del global_config + + # load a training session model + model, model_config = Trainer.load_model_from_training_session( + traindir=path.parent, model_name=path.name + ) + model = model.to(device) + model_r_max = model_config["r_max"] + type_names = model_config["type_names"] + model.eval() + return model, loaded_deployed_model, model_r_max, type_names + + def main(args=None, running_as_script: bool = True): # in results dir, do: nequip-deploy build --train-dir . deployed.pth parser = argparse.ArgumentParser( @@ -38,7 +87,7 @@ def main(args=None, running_as_script: bool = True): Prints only the final result in `name = num` format to stdout; all other information is `logging.debug`ed to stderr. - WARNING: Please note that results of CUDA models are rarely exactly reproducible, and that even CPU models can be nondeterministic. + Please note that results of CUDA models are rarely exactly reproducible, and that even CPU models can be nondeterministic. This is very rarely important in practice, but can be unintuitive. """ ) ) @@ -68,7 +117,7 @@ def main(args=None, running_as_script: bool = True): ) parser.add_argument( "--test-indexes", - help="Path to a file containing the indexes in the dataset that make up the test set. If omitted, all data frames *not* used as training or validation data in the training session `train_dir` will be used.", + help="Path to a file containing the indexes in the dataset that make up the test set. If omitted, all data frames *not* used as training or validation data in the training session `train_dir` will be used. PyTorch, YAML, and JSON formats containing a list of integers are supported.", type=Path, default=None, ) @@ -112,6 +161,12 @@ def main(args=None, running_as_script: bool = True): type=str, default="", ) + parser.add_argument( + "--output-fields-from-original-dataset", + help="Extra fields from the ORIGINAL REFERENCE DATASET (names comma separated with no spaces) to write to the `--output` with the added prefix `original_dataset_*`", + type=str, + default="", + ) parser.add_argument( "--log", help="log file to store all the metrics and screen logging.debug", @@ -164,9 +219,20 @@ def main(args=None, running_as_script: bool = True): if args.output is not None: if args.output.suffix != ".xyz": raise ValueError("Only .xyz format for `--output` is supported.") - args.output_fields = [e for e in args.output_fields.split(",") if e != ""] + [ - ORIGINAL_DATASET_INDEX_KEY + args.output_fields_from_original_dataset = [ + e for e in args.output_fields_from_original_dataset.split(",") if e != "" ] + args.output_fields = [e for e in args.output_fields.split(",") if e != ""] + ase_all_fields = ( + args.output_fields + + [ + ORIGINAL_DATASET_PREFIX + e + for e in args.output_fields_from_original_dataset + ] + + [ORIGINAL_DATASET_INDEX_KEY] + ) + if len(args.output_fields_from_original_dataset) > 0: + _register_field_prefix(ORIGINAL_DATASET_PREFIX) output_type = "xyz" else: assert args.output_fields == "" @@ -185,7 +251,7 @@ def main(args=None, running_as_script: bool = True): logger.info(f"Using device: {device}") if device.type == "cuda": logger.info( - "WARNING: please note that models running on CUDA are usually nondeterministc and that this manifests in the final test errors; for a _more_ deterministic result, please use `--device cpu`", + "Please note that _all_ machine learning models running on CUDA hardware are generally somewhat nondeterministic and that this can manifest in small, generally unimportant variation in the final test errors.", ) if args.use_deterministic_algorithms: @@ -196,41 +262,10 @@ def main(args=None, running_as_script: bool = True): # Load model: logger.info("Loading model... ") - loaded_deployed_model: bool = False - model_r_max = None - try: - model, metadata = load_deployed_model( - args.model, - device=device, - set_global_options=True, # don't warn that setting - ) - logger.info("loaded deployed model.") - # the global settings for a deployed model are set by - # set_global_options in the call to load_deployed_model - # above - model_r_max = float(metadata[R_MAX_KEY]) - loaded_deployed_model = True - except ValueError: # its not a deployed model - loaded_deployed_model = False - # we don't do this in the `except:` block to avoid "during handing of this exception another exception" - # chains if there is an issue loading the training session model. This makes the error messages more - # comprehensible: - if not loaded_deployed_model: - # Use the model config, regardless of dataset config - global_config = args.model.parent / "config.yaml" - global_config = Config.from_file(str(global_config), defaults=default_config) - _set_global_options(global_config) - check_code_version(global_config) - del global_config - - # load a training session model - model, model_config = Trainer.load_model_from_training_session( - traindir=args.model.parent, model_name=args.model.name - ) - model = model.to(device) - logger.info("loaded model from training session") - model_r_max = model_config["r_max"] - model.eval() + model, loaded_deployed_model, model_r_max, _ = _load_deployed_or_traindir( + args.model, device=device + ) + logger.info(f" loaded{' deployed' if loaded_deployed_model else ''} model") # Load a config file logger.info( @@ -374,22 +409,27 @@ def main(args=None, running_as_script: bool = True): with torch.no_grad(): # Write output if output_type == "xyz": + output_out = out.copy() # add test frame to the output: - out[ORIGINAL_DATASET_INDEX_KEY] = torch.LongTensor( + output_out[ORIGINAL_DATASET_INDEX_KEY] = torch.LongTensor( this_batch_test_indexes ) + for field in args.output_fields_from_original_dataset: + # batch is from the original dataset + output_out[ORIGINAL_DATASET_PREFIX + field] = batch[field] # append to the file ase.io.write( output, - AtomicData.from_AtomicDataDict(out) + AtomicData.from_AtomicDataDict(output_out) .to(device="cpu") .to_ase( type_mapper=dataset.type_mapper, - extra_fields=args.output_fields, + extra_fields=ase_all_fields, ), format="extxyz", append=True, ) + del output_out # Accumulate metrics if do_metrics: diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 88b55f7e..3d10049b 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -16,31 +16,36 @@ from nequip.utils import Config from nequip.data import dataset_from_config from nequip.utils import load_file +from nequip.utils.config import _GLOBAL_ALL_ASKED_FOR_KEYS from nequip.utils.test import assert_AtomicData_equivariant from nequip.utils.versions import check_code_version +from nequip.utils.misc import get_default_device_name from nequip.utils._global_options import _set_global_options from nequip.scripts._logger import set_up_script_logger default_config = dict( root="./", - run_name="NequIP", + tensorboard=False, wandb=False, - wandb_project="NequIP", model_builders=[ "SimpleIrrepsConfig", "EnergyModel", "PerSpeciesRescale", - "ForceOutput", + "StressForceOutput", "RescaleEnergyEtc", ], dataset_statistics_stride=1, - default_dtype="float32", - allow_tf32=False, # TODO: until we understand equivar issues + device=get_default_device_name(), + default_dtype="float64", + model_dtype="float32", + allow_tf32=True, verbose="INFO", model_debug_mode=False, equivariance_test=False, grad_anomaly_mode=False, + gpu_oom_offload=False, append=False, + warn_unused=False, _jit_bailout_depth=2, # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286 # Quote from eelison in PyTorch slack: # https://pytorch.slack.com/archives/CDZD1FANA/p1644259272007529?thread_ts=1644064449.039479&cid=CDZD1FANA @@ -51,7 +56,13 @@ # We default to DYNAMIC alone because the number of edges is always dynamic, # even if the number of atoms is fixed: _jit_fusion_strategy=[("DYNAMIC", 3)], + # Due to what appear to be ongoing bugs with nvFuser, we default to NNC (fuser1) for now: + # TODO: still default to NNC on CPU regardless even if change this for GPU + # TODO: default for ROCm? + _jit_fuser="fuser1", ) +# All default_config keys are valid / requested +_GLOBAL_ALL_ASKED_FOR_KEYS.update(default_config.keys()) def main(args=None, running_as_script: bool = True): @@ -75,7 +86,22 @@ def main(args=None, running_as_script: bool = True): # Train trainer.save() - trainer.train() + if config.get("gpu_oom_offload", False): + if not torch.cuda.is_available(): + raise RuntimeError( + "CUDA is not available; --gpu-oom-offload doesn't make sense." + ) + warnings.warn( + "! GPU OOM Offloading is ON:\n" + "This is meant for training models that would be impossible otherwise due to OOM.\n" + "Note that this comes at a speed cost and SHOULD NOT be used if your training fits in GPU memory without it.\n" + "Please also consider whether a smaller model is a more appropriate solution!\n" + "Also, a warning from PyTorch: 'If you overuse pinned memory, it can cause serious problems when running low on RAM!'" + ) + with torch.autograd.graph.save_on_cpu(pin_memory=True): + trainer.train() + else: + trainer.train() return @@ -104,16 +130,32 @@ def parse_command_line(args=None): help="enable PyTorch autograd anomaly mode to debug NaN gradients. Do not use for production training!", action="store_true", ) + parser.add_argument( + "--gpu-oom-offload", + help="Use `torch.autograd.graph.save_on_cpu` to offload intermediate tensors to CPU (host) memory in order to train models that would be impossible otherwise due to OOM. Note that this comes as at a speed cost and SHOULD NOT be used if your training fits in GPU memory without it. Please also consider whether a smaller model is a more appropriate solution.", + action="store_true", + ) parser.add_argument( "--log", help="log file to store all the screen logging", type=Path, default=None, ) + parser.add_argument( + "--warn-unused", + help="Warn instead of error when the config contains unused keys", + action="store_true", + ) args = parser.parse_args(args=args) config = Config.from_file(args.config, defaults=default_config) - for flag in ("model_debug_mode", "equivariance_test", "grad_anomaly_mode"): + for flag in ( + "model_debug_mode", + "equivariance_test", + "grad_anomaly_mode", + "warn_unused", + "gpu_oom_offload", + ): config[flag] = getattr(args, flag) or config[flag] return config @@ -123,22 +165,28 @@ def fresh_start(config): # we use add_to_config cause it's a fresh start and need to record it check_code_version(config, add_to_config=True) _set_global_options(config) + if config["default_dtype"] != "float64": + warnings.warn( + f"default_dtype={config['default_dtype']} but we strongly recommend float64" + ) # = Make the trainer = if config.wandb: + import wandb # noqa: F401 - from nequip.train.trainer_wandb import TrainerWandB + from nequip.train.trainer_wandb import TrainerWandB as Trainer # download parameters from wandb in case of sweeping from nequip.utils.wandb import init_n_update config = init_n_update(config) - trainer = TrainerWandB(model=None, **dict(config)) + elif config.tensorboard: + from nequip.train.trainer_tensorboard import TrainerTensorBoard as Trainer else: from nequip.train.trainer import Trainer - trainer = Trainer(model=None, **dict(config)) + trainer = Trainer(model=None, **Config.as_dict(config)) # what is this # to update wandb data? @@ -165,9 +213,6 @@ def fresh_start(config): ) logging.info("Successfully built the network...") - # by doing this here we check also any keys custom builders may have added - _check_old_keys(config) - # Equivar test if config.equivariance_test > 0: n_train: int = len(trainer.dataset_train) @@ -195,6 +240,19 @@ def fresh_start(config): # Store any updated config information in the trainer trainer.update_kwargs(config) + # Only run the unused check as a callback after the trainer has + # initialized everything (metrics, early stopping, etc.) + def _unused_check(): + unused = config._unused_keys() + if len(unused) > 0: + message = f"The following keys in the config file were not used, did you make a typo?: {', '.join(unused)}. (If this sounds wrong, please file an issue. You can turn this error into a warning with `--warn-unused`, but please make sure that the key really is correctly spelled and used!.)" + if config.warn_unused: + warnings.warn(message) + else: + raise KeyError(message) + + trainer._post_init_callback = _unused_check + return trainer @@ -262,16 +320,5 @@ def restart(config): return trainer -def _check_old_keys(config) -> None: - """check ``config`` for old/depricated keys and emit corresponding errors/warnings""" - # compile_model - k = "compile_model" - if k in config: - if config[k]: - raise ValueError("the `compile_model` option has been removed") - else: - warnings.warn("the `compile_model` option has been removed") - - if __name__ == "__main__": main(running_as_script=True) diff --git a/nequip/train/_loss.py b/nequip/train/_loss.py index 6df59fe3..6442c0d4 100644 --- a/nequip/train/_loss.py +++ b/nequip/train/_loss.py @@ -4,7 +4,7 @@ import torch.nn from torch_runstats.scatter import scatter, scatter_mean -from nequip.data import AtomicDataDict +from nequip.data import AtomicDataDict, _GRAPH_FIELDS from nequip.utils import instantiate_from_cls_name @@ -44,17 +44,20 @@ def __call__( key: str, mean: bool = True, ): + ref = ref[key] + # make sure prediction is promoted to dtype of reference + pred = pred[key].to(ref.dtype) # zero the nan entries - has_nan = self.ignore_nan and torch.isnan(ref[key].mean()) + has_nan = self.ignore_nan and torch.isnan(ref.mean()) if has_nan: - not_nan = (ref[key] == ref[key]).int() - loss = self.func(pred[key], torch.nan_to_num(ref[key], nan=0.0)) * not_nan + not_nan = (ref == ref).int() + loss = self.func(pred, torch.nan_to_num(ref, nan=0.0)) * not_nan if mean: return loss.sum() / not_nan.sum() else: return loss else: - loss = self.func(pred[key], ref[key]) + loss = self.func(pred, ref) if mean: return loss.mean() else: @@ -69,28 +72,34 @@ def __call__( key: str, mean: bool = True, ): + if key not in _GRAPH_FIELDS: + raise RuntimeError( + f"Doesn't make sense to do a `PerAtom` loss on field `{key}`, which isn't registered as a graph (global) field. If it is a graph-level field, register it with `graph_fields: [\"{key}\"]`; otherwise you don't need to specify `PerAtom` for loss on per-node fields." + ) + ref_dict = ref + ref = ref[key] + # make sure prediction is promoted to dtype of reference + pred = pred[key].to(ref.dtype) # zero the nan entries - has_nan = self.ignore_nan and torch.isnan(ref[key].sum()) - N = torch.bincount(ref[AtomicDataDict.BATCH_KEY]) + has_nan = self.ignore_nan and torch.isnan(ref.sum()) + N = torch.bincount(ref_dict[AtomicDataDict.BATCH_KEY]) N = N.reshape((-1, 1)) if has_nan: - not_nan = (ref[key] == ref[key]).int() - loss = ( - self.func(pred[key], torch.nan_to_num(ref[key], nan=0.0)) * not_nan / N - ) + not_nan = (ref == ref).int() + loss = self.func(pred, torch.nan_to_num(ref, nan=0.0)) * not_nan / N if self.func_name == "MSELoss": loss = loss / N - assert loss.shape == pred[key].shape # [atom, dim] + assert loss.shape == pred.shape # [atom, dim] if mean: return loss.sum() / not_nan.sum() else: return loss else: - loss = self.func(pred[key], ref[key]) + loss = self.func(pred, ref) loss = loss / N if self.func_name == "MSELoss": loss = loss / N - assert loss.shape == pred[key].shape # [atom, dim] + assert loss.shape == pred.shape # [atom, dim] if mean: return loss.mean() else: @@ -113,20 +122,22 @@ def __call__( ): if not mean: raise NotImplementedError("Cannot handle this yet") + ref = ref[key] + # make sure prediction is promoted to dtype of reference + pred_dict = pred + pred = pred[key].to(ref.dtype) - has_nan = self.ignore_nan and torch.isnan(ref[key].mean()) + has_nan = self.ignore_nan and torch.isnan(ref.mean()) if has_nan: - not_nan = (ref[key] == ref[key]).int() - per_atom_loss = ( - self.func(pred[key], torch.nan_to_num(ref[key], nan=0.0)) * not_nan - ) + not_nan = (ref == ref).int() + per_atom_loss = self.func(pred, torch.nan_to_num(ref, nan=0.0)) * not_nan else: - per_atom_loss = self.func(pred[key], ref[key]) + per_atom_loss = self.func(pred, ref) reduce_dims = tuple(i + 1 for i in range(len(per_atom_loss.shape) - 1)) - spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1) + spe_idx = pred_dict[AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1) if has_nan: if len(reduce_dims) > 0: per_atom_loss = per_atom_loss.sum(dim=reduce_dims) diff --git a/nequip/train/callbacks/loss_schedule.py b/nequip/train/callbacks/loss_schedule.py new file mode 100644 index 00000000..edd6f173 --- /dev/null +++ b/nequip/train/callbacks/loss_schedule.py @@ -0,0 +1,54 @@ +from typing import Dict, List, Tuple +from dataclasses import dataclass +import numpy as np + +from nequip.train import Trainer, Loss + +# Making this a dataclass takes care of equality operators, handing restart consistency checks + + +@dataclass +class SimpleLossSchedule: + """Schedule `loss_coeffs` through a training run. + + To use this in a training, set in your YAML file: + + start_of_epoch_callbacks: + - !!python/object:nequip.train.callbacks.loss_schedule.SimpleLossSchedule {"schedule": [[30, {"forces": 1.0, "total_energy": 0.0}], [30, {"forces": 0.0, "total_energy": 1.0}]]} + + This funny syntax tells PyYAML to construct an object of this class. + + Each entry in the schedule is a tuple of the 1-based epoch index to start that loss coefficient set at, and a dict of loss coefficients. + """ + + schedule: List[Tuple[int, Dict[str, float]]] = None + + def __call__(self, trainer: Trainer): + assert ( + self in trainer._start_of_epoch_callbacks + ), "must be start not end of epoch" + # user-facing 1 based indexing of epochs rather than internal zero based + iepoch: int = trainer.iepoch + 1 + if iepoch < 1: # initial validation epoch is 0 in user-facing indexing + return + loss_function: Loss = trainer.loss + + assert self.schedule is not None + schedule_start_epochs = np.asarray([e[0] for e in self.schedule]) + # make sure they are ascending + assert len(schedule_start_epochs) >= 1 + assert schedule_start_epochs[0] >= 2, "schedule must start at epoch 2 or later" + assert np.all( + (schedule_start_epochs[1:] - schedule_start_epochs[:-1]) > 0 + ), "schedule start epochs must be strictly ascending" + # we are running at _start_ of epoch, so we need to apply the right change for the current epoch + current_change_idex = np.searchsorted(schedule_start_epochs, iepoch + 1) - 1 + # ^ searchsorted 3 in [2, 10, 19] would return 1, for example + # but searching 2 in [2, 10, 19] gives 0, so we actually search iepoch + 1 to always be ahead of the start + # apply the current change to handle restarts + if current_change_idex >= 0: + new_coeffs = self.schedule[current_change_idex][1] + assert ( + loss_function.coeffs.keys() == new_coeffs.keys() + ), "all coeff schedules must contain all loss terms" + loss_function.coeffs.update(new_coeffs) diff --git a/nequip/train/loss.py b/nequip/train/loss.py index 1420fc22..fe5144da 100644 --- a/nequip/train/loss.py +++ b/nequip/train/loss.py @@ -39,10 +39,7 @@ class Loss: def __init__( self, coeffs: Union[dict, str, List[str]], - coeff_schedule: str = "constant", ): - - self.coeff_schedule = coeff_schedule self.coeffs = {} self.funcs = {} self.keys = [] diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 55efec32..bdfb4f17 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -26,7 +26,14 @@ import torch from torch_ema import ExponentialMovingAverage -from nequip.data import DataLoader, AtomicData, AtomicDataDict, AtomicDataset +from nequip.data import ( + DataLoader, + PartialSampler, + AtomicData, + AtomicDataDict, + AtomicDataset, +) +from nequip.nn import GraphModel from nequip.utils import ( Output, Config, @@ -38,10 +45,11 @@ atomic_write, finish_all_writes, atomic_write_group, - dtype_from_name, ) from nequip.utils.versions import check_code_version from nequip.model import model_from_config +from nequip.utils.config import _GLOBAL_ALL_ASKED_FOR_KEYS +from nequip.utils.misc import get_default_device_name from .loss import Loss, LossStat from .metrics import Metrics @@ -150,6 +158,7 @@ class Trainer: validation_batch_size (int): batch size for evaluating the model for validation shuffle (bool): parameters for dataloader n_train (int): # of frames for training + n_train_per_epoch (optional int): how many frames from `n_train` to use each epoch; see `PartialSampler`. When `None`, all `n_train` frames will be used each epoch. n_val (int): # of frames for validation exclude_keys (list): fields from dataset to ignore. dataloader_num_workers (int): `num_workers` for the `DataLoader`s @@ -211,11 +220,13 @@ class Trainer: lr_scheduler_module = torch.optim.lr_scheduler optim_module = torch.optim + model: GraphModel + def __init__( self, model, model_builders: Optional[list] = [], - device: str = "cuda" if torch.cuda.is_available() else "cpu", + device: str = get_default_device_name(), seed: Optional[int] = None, dataset_seed: Optional[int] = None, loss_coeffs: Union[dict, str] = AtomicDataDict.TOTAL_ENERGY_KEY, @@ -240,17 +251,19 @@ def __init__( validation_batch_size: int = 5, shuffle: bool = True, n_train: Optional[int] = None, + n_train_per_epoch: Optional[int] = None, n_val: Optional[int] = None, dataloader_num_workers: int = 0, train_idcs: Optional[list] = None, val_idcs: Optional[list] = None, train_val_split: str = "random", init_callbacks: list = [], + start_of_epoch_callbacks: list = [], end_of_epoch_callbacks: list = [], end_of_batch_callbacks: list = [], end_of_train_callbacks: list = [], final_callbacks: list = [], - log_batch_freq: int = 1, + log_batch_freq: int = 100, log_epoch_freq: int = 1, save_checkpoint_freq: int = -1, save_ema_checkpoint_freq: int = -1, @@ -269,6 +282,8 @@ def __init__( for key in self.init_keys: setattr(self, key, locals()[key]) _local_kwargs[key] = locals()[key] + # all init_keys of the Trainer are valid config keys + _GLOBAL_ALL_ASKED_FOR_KEYS.add(key) self.ema = None @@ -295,13 +310,14 @@ def __init__( self.trainer_save_path = output.generate_file("trainer.pth") self.config_path = self.output.generate_file("config.yaml") - if seed is not None: - torch.manual_seed(seed) - np.random.seed(seed) + if seed is None: + raise ValueError("seed is required") + + torch.manual_seed(seed) + np.random.seed(seed) self.dataset_rng = torch.Generator() - if dataset_seed is not None: - self.dataset_rng.manual_seed(dataset_seed) + self.dataset_rng.manual_seed(dataset_seed if dataset_seed is not None else seed) self.logger.info(f"Torch device: {self.device}") self.torch_device = torch.device(self.device) @@ -330,26 +346,12 @@ def __init__( self.train_on_keys = self.loss.keys if train_on_keys is not None: assert set(train_on_keys) == set(self.train_on_keys) - self._remove_from_model_input = set(self.train_on_keys) - if ( - len( - self._remove_from_model_input.intersection( - AtomicDataDict.ALL_ENERGY_KEYS - ) - ) - > 0 - ): - # if we are training on _any_ of the energy quantities (energy, force, partials, stress, etc.) - # then none of them should be fed into the model - self._remove_from_model_input = self._remove_from_model_input.union( - AtomicDataDict.ALL_ENERGY_KEYS - ) - if kwargs.get("_override_allow_truth_label_inputs", False): - # needed for unit testing models - self._remove_from_model_input = set() # load all callbacks self._init_callbacks = [load_callable(callback) for callback in init_callbacks] + self._start_of_epoch_callbacks = [ + load_callable(callback) for callback in start_of_epoch_callbacks + ] self._end_of_epoch_callbacks = [ load_callable(callback) for callback in end_of_epoch_callbacks ] @@ -671,6 +673,19 @@ def load_model_from_training_session( device="cpu", config_dictionary: Optional[dict] = None, ) -> Tuple[torch.nn.Module, Config]: + """Load a model from a training session. + + Note that this uses ``model_from_config`` internally and is thus not thread safe. + + Args: + traindir: the training session + model_name: which checkpoint to load; defaults to ``best_model.pth`` + device: target device to load to, defaults to ``cpu`` + config_dictionary: optionally use this config instead of ``traindir/config.yaml`` + + Returns: + (model, config) + """ traindir = str(traindir) model_name = str(model_name) @@ -679,21 +694,14 @@ def load_model_from_training_session( else: config = Config.from_file(traindir + "/config.yaml") + # model_from_config takes care of dtypes already model = model_from_config( config=config, initialize=False, ) - if model is not None: # TODO: why would it be? - # TODO: this is not exactly equivalent to building with - # this set as default dtype... does it matter? - model.to( - device=torch.device(device), - dtype=dtype_from_name(config.default_dtype), - ) - model_state_dict = torch.load( - traindir + "/" + model_name, map_location=device - ) - model.load_state_dict(model_state_dict) + model.to(device=torch.device(device)) + model_state_dict = torch.load(traindir + "/" + model_name, map_location=device) + model.load_state_dict(model_state_dict) return model, config @@ -701,6 +709,7 @@ def init(self): """initialize optimizer""" if self.model is None: return + assert isinstance(self.model, GraphModel) self.model.to(self.torch_device) @@ -710,12 +719,6 @@ def init(self): f"Number of trainable weights: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}" ) - self.rescale_layers = [] - outer_layer = self.model - while hasattr(outer_layer, "unscale"): - self.rescale_layers.append(outer_layer) - outer_layer = getattr(outer_layer, "model", None) - self.init_objects() self._initialized = True @@ -773,6 +776,9 @@ def train(self): self.init_metrics() + if getattr(self, "_post_init_callback", None) is not None: + self._post_init_callback() + while not self.stop_cond: self.epoch_step() @@ -799,32 +805,22 @@ def batch_step(self, data, validation=False): data = data.to(self.torch_device) data = AtomicData.to_AtomicDataDict(data) - data_unscaled = data - for layer in self.rescale_layers: - # This means that self.model is RescaleOutputs - # this will normalize the targets - # in validation (eval mode), it does nothing - # in train mode, if normalizes the targets - data_unscaled = layer.unscale(data_unscaled) + # this will normalize the targets + # in both validation and train we want targets normalized _for the loss_ + data_for_loss = self.model.unscale(data, force_process=True) # Run model # We make a shallow copy of the input dict in case the model modifies it - input_data = { - k: v - for k, v in data_unscaled.items() - if k not in self._remove_from_model_input - } - out = self.model(input_data) - del input_data + out = self.model(data_for_loss) # If we're in evaluation mode (i.e. validation), then - # data_unscaled's target prop is unnormalized, and out's has been rescaled to be in the same units - # If we're in training, data_unscaled's target prop has been normalized, and out's hasn't been touched, so they're both in normalized units - # Note that either way all normalization was handled internally by RescaleOutput + # data_for_loss's target prop is unnormalized, and out's has been rescaled to be in the same units + # If we're in training, data_for_loss's target prop has been normalized, and out's hasn't been touched, so they're both in normalized units + # Note that either way all normalization was handled internally by GraphModel via RescaleOutput if not validation: # Actually do an optimization step, since we're training: - loss, loss_contrib = self.loss(pred=out, ref=data_unscaled) + loss, loss_contrib = self.loss(pred=out, ref=data_for_loss) # see https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-parameter-grad-none-instead-of-model-zero-grad-or-optimizer-zero-grad self.optim.zero_grad(set_to_none=True) loss.backward() @@ -846,25 +842,26 @@ def batch_step(self, data, validation=False): with torch.no_grad(): if validation: - scaled_out = out - _data_unscaled = data - for layer in self.rescale_layers: - # loss function always needs to be in normalized unit - scaled_out = layer.unscale(scaled_out, force_process=True) - _data_unscaled = layer.unscale(_data_unscaled, force_process=True) - loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled) + # loss function always needs to be in normalized unit + normalized_units_out = self.model.unscale(out, force_process=True) + # data_for_loss is always forced into normalized units + loss, loss_contrib = self.loss( + pred=normalized_units_out, ref=data_for_loss + ) + del normalized_units_out + # everything else is already in real units for metrics, so do nothing else: # If we are in training mode, we need to bring the prediction - # into real units - for layer in self.rescale_layers[::-1]: - out = layer.scale(out, force_process=True) + # into real units for metrics + out = self.model.scale(out, force_process=True) # save metrics stats self.batch_losses = self.loss_stat(loss, loss_contrib) - # in validation mode, data is in real units and the network scales + # in validation mode, reference data is in real units and the network scales # out to be in real units interally. - # in training mode, data is still in real units, and we rescaled - # out to be in real units above. + # in training mode, reference data is still in real units, and we rescaled + # network predicted out to be in real units right above + # thus, we get metrics in real units always: self.batch_metrics = self.metrics(pred=out, ref=data) @property @@ -894,12 +891,18 @@ def reset_metrics(self): self.metrics.to(self.torch_device) def epoch_step(self): + for callback in self._start_of_epoch_callbacks: + callback(self) dataloaders = {TRAIN: self.dl_train, VALIDATION: self.dl_val} categories = [TRAIN, VALIDATION] if self.iepoch >= 0 else [VALIDATION] dataloaders = [ dataloaders[c] for c in categories ] # get the right dataloaders for the catagories we actually run + if TRAIN in categories: + # We have to step the sampler so it knows what epoch it is + self.dl_train_sampler.step_epoch(self.iepoch) + self.metrics_dict = {} self.loss_dict = {} @@ -1221,10 +1224,21 @@ def set_dataset( # use the right randomness generator=self.dataset_rng, ) + if self.n_train_per_epoch is not None: + assert self.n_train_per_epoch % self.batch_size == 0 + self.dl_train_sampler = PartialSampler( + data_source=self.dataset_train, + # training should shuffle (if enabled) + shuffle=self.shuffle, + # if n_train_per_epoch is None (default), it's set to len(self.dataset_train) == n_train + # i.e. use all `n_train` frames each epoch + num_samples_per_epoch=self.n_train_per_epoch, + generator=self.dataset_rng, + ) self.dl_train = DataLoader( dataset=self.dataset_train, - shuffle=self.shuffle, # training should shuffle batch_size=self.batch_size, + sampler=self.dl_train_sampler, **dl_kwargs, ) # validation, on the other hand, shouldn't shuffle diff --git a/nequip/train/trainer_tensorboard.py b/nequip/train/trainer_tensorboard.py new file mode 100644 index 00000000..de76cbe9 --- /dev/null +++ b/nequip/train/trainer_tensorboard.py @@ -0,0 +1,31 @@ +from torch.utils.tensorboard import SummaryWriter + +from .trainer import Trainer, TRAIN, VALIDATION + + +class TrainerTensorBoard(Trainer): + """Trainer class that adds WandB features""" + + def end_of_epoch_log(self): + Trainer.end_of_epoch_log(self) + kwargs = dict( + global_step=self.iepoch, walltime=self.mae_dict["cumulative_wall"] + ) + for k, v in self.mae_dict.items(): + terms = k.split("_") + if terms[0] in [TRAIN, VALIDATION]: + header = "/".join(terms[1:]) + self.tb_writer.add_scalar(f"{header}/{terms[0]}", v, **kwargs) + elif k not in ["cumulative_wall", "epoch"]: + self.tb_writer.add_scalar(k, v, **kwargs) + self.tb_writer.flush() + + def init(self): + super().init() + + if not self._initialized: + return + + self.tb_writer = SummaryWriter( + log_dir=f"{self.output.root}/tb_summary/{self.output.run_name}", + ) diff --git a/nequip/utils/__init__.py b/nequip/utils/__init__.py index e7dd0912..46ab22ec 100644 --- a/nequip/utils/__init__.py +++ b/nequip/utils/__init__.py @@ -14,7 +14,7 @@ from .config import Config from .output import Output from .modules import find_first_of_type -from .misc import dtype_from_name +from .misc import dtype_from_name, torch_default_dtype __all__ = [ instantiate_from_cls_name, @@ -30,4 +30,5 @@ Output, find_first_of_type, dtype_from_name, + torch_default_dtype, ] diff --git a/nequip/utils/_global_options.py b/nequip/utils/_global_options.py index 907a9ed9..bc5bc2d9 100644 --- a/nequip/utils/_global_options.py +++ b/nequip/utils/_global_options.py @@ -1,4 +1,6 @@ import warnings +from packaging import version +import os import torch @@ -9,6 +11,7 @@ from .misc import dtype_from_name from .auto_init import instantiate from .test import set_irreps_debug +from .config import Config # for multiprocessing, we need to keep track of our latest global options so @@ -35,7 +38,7 @@ def _set_global_options(config, warn_on_override: bool = False) -> None: """ # update these options into the latest global config. global _latest_global_config - _latest_global_config.update(dict(config)) + _latest_global_config.update(Config.as_dict(config)) # Set TF32 support # See https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if torch.cuda.is_available() and "allow_tf32" in config: @@ -48,7 +51,16 @@ def _set_global_options(config, warn_on_override: bool = False) -> None: torch.backends.cuda.matmul.allow_tf32 = config["allow_tf32"] torch.backends.cudnn.allow_tf32 = config["allow_tf32"] - if int(torch.__version__.split(".")[1]) >= 11: + # Temporary warning due to unresolved upstream issue + torch_version = version.parse(torch.__version__) + if torch_version < version.parse("1.11"): + warnings.warn("We currently recommend the use of PyTorch 1.11") + elif torch_version > version.parse("1.11"): + warnings.warn( + "!! Upstream issues in PyTorch versions >1.11 have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. At present we *strongly* recommend the use of PyTorch 1.11 if using CUDA devices; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue." + ) + + if torch_version >= version.parse("1.11"): # PyTorch >= 1.11 k = "_jit_fusion_strategy" if k in config: @@ -70,11 +82,41 @@ def _set_global_options(config, warn_on_override: bool = False) -> None: f"Setting the GLOBAL value for jit bailout depth to `{new_depth}` which is different than the previous value of `{old_depth}`" ) + # Deal with fusers + # The default PyTorch fuser changed to nvFuser in 1.12 + # fuser1 is NNC, fuser2 is nvFuser + # See https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#fusers + # And https://github.com/pytorch/pytorch/blob/e0a0f37a11164f59b42bc80a6f95b54f722d47ce/torch/jit/_fuser.py#L46 + # Also https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/codegen/cuda/README.md + # Also https://github.com/pytorch/pytorch/blob/66fb83293e6a6f527d3fde632e3547fda20becea/torch/csrc/jit/OVERVIEW.md?plain=1#L1201 + # https://github.com/search?q=repo%3Apytorch%2Fpytorch%20PYTORCH_JIT_USE_NNC_NOT_NVFUSER&type=code + # We follow the approach they have explicitly built for disabling nvFuser in favor of NNC: + # https://github.com/pytorch/pytorch/blob/66fb83293e6a6f527d3fde632e3547fda20becea/torch/csrc/jit/codegen/cuda/README.md?plain=1#L214 + # + # There are three ways to disable nvfuser. Listed below with descending priorities: + # - Force using NNC instead of nvfuser for GPU fusion with env variable `export PYTORCH_JIT_USE_NNC_NOT_NVFUSER=1`. + # - Disabling nvfuser with torch API `torch._C._jit_set_nvfuser_enabled(False)`. + # - Disable nvfuser with env variable `export PYTORCH_JIT_ENABLE_NVFUSER=0`. + # + k = "PYTORCH_JIT_USE_NNC_NOT_NVFUSER" + if k in os.environ: + warnings.warn( + "Do NOT manually set PYTORCH_JIT_USE_NNC_NOT_NVFUSER=0 unless you know exactly what you're doing!" + ) + else: + os.environ[k] = "1" + # TODO: warn_on_override for the rest here? if config.get("model_debug_mode", False): set_irreps_debug(enabled=True) if "default_dtype" in config: - torch.set_default_dtype(dtype_from_name(config["default_dtype"])) + old_dtype = torch.get_default_dtype() + new_dtype = dtype_from_name(config["default_dtype"]) + if warn_on_override and old_dtype != new_dtype: + warnings.warn( + f"Setting the GLOBAL value for torch.set_default_dtype to `{new_dtype}` which is different than the previous value of `{old_dtype}`" + ) + torch.set_default_dtype(new_dtype) if config.get("grad_anomaly_mode", False): torch.autograd.set_detect_anomaly(True) diff --git a/nequip/utils/auto_init.py b/nequip/utils/auto_init.py index 8a9a9917..157c9ce4 100644 --- a/nequip/utils/auto_init.py +++ b/nequip/utils/auto_init.py @@ -2,7 +2,7 @@ import inspect import logging -from .config import Config +from .config import Config, _GLOBAL_ALL_ASKED_FOR_KEYS def instantiate_from_cls_name( @@ -140,7 +140,7 @@ def instantiate( if k not in key_mapping["optional"] } - final_optional_args = dict(config) + final_optional_args = Config.as_dict(config) # for nested argument, it is possible that the positional args contain unnecesary keys if len(parent_builders) > 0: @@ -213,21 +213,32 @@ def instantiate( for t in key_mapping: key_mapping[t].pop(key, None) + # debug info + if len(parent_builders) == 0: + # ^ we only want to log or consume arguments for the "unused keys" check + # if this is a root-level build. For subbuilders, we don't want to log + # or, worse, mark keys without prefixes as consumed. + logging.debug( + f"{'get args for' if return_args_only else 'instantiate'} {builder.__name__}" + ) + for t in key_mapping: + for k, v in key_mapping[t].items(): + string = f" {t:>10s}_args : {k:>50s}" + # key mapping tells us how values got from the + # users config (v) to the object being built (k) + # thus v is by definition a valid key + _GLOBAL_ALL_ASKED_FOR_KEYS.add(v) + if k != v: + string += f" <- {v:>50s}" + logging.debug(string) + logging.debug(f"...{builder.__name__}_param = dict(") + logging.debug(f"... optional_args = {final_optional_args},") + logging.debug(f"... positional_args = {positional_args})") + + # Short circuit for return_args_only if return_args_only: return key_mapping, final_optional_args - - # debug info - logging.debug(f"instantiate {builder.__name__}") - for t in key_mapping: - for k, v in key_mapping[t].items(): - string = f" {t:>10s}_args : {k:>50s}" - if k != v: - string += f" <- {v:>50s}" - logging.debug(string) - logging.debug(f"...{builder.__name__}_param = dict(") - logging.debug(f"... optional_args = {final_optional_args},") - logging.debug(f"... positional_args = {positional_args})") - + # Otherwise, actually build the thing: try: instance = builder(**positional_args, **final_optional_args) except Exception as e: diff --git a/nequip/utils/config.py b/nequip/utils/config.py index d13e0546..ca79f576 100644 --- a/nequip/utils/config.py +++ b/nequip/utils/config.py @@ -34,6 +34,8 @@ If a parameter is updated, the updated value will be formatted back to the same type. """ +from typing import Set, Dict, Any, List + import inspect from copy import deepcopy @@ -42,7 +44,12 @@ from nequip.utils.savenload import save_file, load_file +_GLOBAL_ALL_ASKED_FOR_KEYS: Set[str] = set() + + class Config(object): + _items: Dict[str, Any] + def __init__( self, config: Optional[dict] = None, @@ -76,10 +83,20 @@ def keys(self): def _as_dict(self): return self._items - def as_dict(self): - return dict(self) + @staticmethod + def as_dict(obj): + # don't use `dict(self)`, since that + # calls __getitem__ + if isinstance(obj, dict): + return obj.copy() + elif isinstance(obj, Config): + return obj._items.copy() + else: + raise TypeError def __getitem__(self, key): + # any requested key is a valid key + _GLOBAL_ALL_ASKED_FOR_KEYS.add(key) return self._items[key] def get_type(self, key): @@ -115,7 +132,6 @@ def allow_list(self): return self._allow_list def __setitem__(self, key, val): - # typehint if key.endswith("_type") and key.startswith("_"): @@ -157,6 +173,7 @@ def __contains__(self, key): return key in self._items def pop(self, *args): + _GLOBAL_ALL_ASKED_FOR_KEYS.add(args[0]) return self._items.pop(*args) def update_w_prefix( @@ -187,7 +204,7 @@ def update_w_prefix( keys = self.update(prefix_dict, allow_val_change=allow_val_change) keys = {k: f"{prefix}_{k}" for k in keys} - for suffix in ["params", "kwargs"]: + for suffix in ["kwargs"]: if f"{prefix}_{suffix}" in dictionary: key3 = self.update( dictionary[f"{prefix}_{suffix}"], @@ -227,6 +244,7 @@ def update(self, dictionary: dict, allow_val_change=None): return set(keys) - set([None]) def get(self, *args): + _GLOBAL_ALL_ASKED_FOR_KEYS.add(args[0]) return self._items.get(*args) def persist(self): @@ -254,7 +272,17 @@ def save(self, filename: str, format: Optional[str] = None): @staticmethod def from_file(filename: str, format: Optional[str] = None, defaults: dict = {}): - """Load arguments from file""" + """Load arguments from file + + Has support for including another config file as a baseline with: + ``` + # example of using another config as a baseline and overriding only selected options + # this option will read in configs/minimal.yaml and take ALL keys from that file + include_file_as_baseline_config: configs/minimal.yaml + # keys specified in this file WILL OVERRIDE keys from the `include_file_as_baseline_config` file + l_max: 1 # overrides l_max: 2 in minimal.yaml + ``` + """ supported_formats = {"yaml": ("yml", "yaml"), "json": "json"} dictionary = load_file( @@ -262,6 +290,23 @@ def from_file(filename: str, format: Optional[str] = None, defaults: dict = {}): filename=filename, enforced_format=format, ) + k: str = "include_file_as_baseline_config" + if k in dictionary: + # allow one level of subloading + baseline_fname = dictionary.pop(k) + dictionary_baseline = load_file( + supported_formats=supported_formats, + filename=baseline_fname, + enforced_format=format, + ) + if k in dictionary_baseline: + raise NotImplementedError( + f"Multiple levels of `{k}` are not allowed, but {baseline_fname} contained `{k}`" + ) + # override baseline options with the main config + dictionary_baseline.update(dictionary) + dictionary = dictionary_baseline + del dictionary_baseline, baseline_fname return Config.from_dict(dictionary, defaults) @staticmethod @@ -338,3 +383,10 @@ def from_function(function, remove_kwargs=False): return Config(config=default_params, allow_list=param_keys) load = from_file + + def _get_nomark(self, key: str) -> Any: + return self._items.get(key) + + def _unused_keys(self) -> List[str]: + unused = [k for k in self.keys() if k not in _GLOBAL_ALL_ASKED_FOR_KEYS] + return unused diff --git a/nequip/utils/gmm.py b/nequip/utils/gmm.py new file mode 100644 index 00000000..8a957826 --- /dev/null +++ b/nequip/utils/gmm.py @@ -0,0 +1,142 @@ +from typing import Optional, Union + +import math +import torch +import numpy as np +from e3nn.util.jit import compile_mode + + +@torch.jit.script +def _compute_log_det_cholesky(matrix_chol: torch.Tensor, n_features: int): + """Compute the log-det of the cholesky decomposition of matrices.""" + + n_components = matrix_chol.size(dim=0) + + # https://github.com/scikit-learn/scikit-learn/blob/d9cfe3f6b1c58dd253dc87cb676ce5171ff1f8a1/sklearn/mixture/_gaussian_mixture.py#L379 + log_det_chol = torch.sum( + torch.log(matrix_chol.view(n_components, -1)[:, :: n_features + 1]), dim=1 + ) + + return log_det_chol + + +@torch.jit.script +def _estimate_log_gaussian_prob( + X: torch.Tensor, means: torch.Tensor, precisions_chol: torch.Tensor +): + """Estimate the log Gaussian probability.""" + + n_features = X.size(dim=1) + + # https://github.com/scikit-learn/scikit-learn/blob/d9cfe3f6b1c58dd253dc87cb676ce5171ff1f8a1/sklearn/mixture/_gaussian_mixture.py#L423 + log_det = _compute_log_det_cholesky(precisions_chol, n_features) + + # dim(X) = [n_sample, n_feature] + # dim(precisions_chol) = [n_component, n_feature, n_feature] + # [n_sample, 1, n_feature] - [1, n_component, n_feature] = [n_sample, n_component, n_feature] + # dim(X_centered) = [n_sample, n_component, n_feature] + X_centered = X.unsqueeze(-2) - means.unsqueeze(0) + log_prob = ( + torch.einsum("zci,cij->zcj", X_centered, precisions_chol).square().sum(dim=-1) + ) + + # https://github.com/scikit-learn/scikit-learn/blob/d9cfe3f6b1c58dd253dc87cb676ce5171ff1f8a1/sklearn/mixture/_gaussian_mixture.py#L454 + return -0.5 * (n_features * math.log(2 * math.pi) + log_prob) + log_det + + +@compile_mode("script") +class GaussianMixture(torch.nn.Module): + """Calculate NLL of samples under a Gaussian Mixture Model (GMM). + + Supports fitting the GMM outside of PyTorch using `sklearn`. + """ + + covariance_type: str + n_components: int + n_features: int + seed: int + + def __init__( + self, + n_features: int, + n_components: Optional[int] = 0, + covariance_type: str = "full", + ): + super(GaussianMixture, self).__init__() + assert covariance_type in ( + "full", + ), f"covariance type was {covariance_type}, should be full" + self.covariance_type = covariance_type + self.n_components = n_components + self.n_features = n_features + + self.register_buffer("means", torch.Tensor()) + self.register_buffer("weights", torch.Tensor()) + self.register_buffer("covariances", torch.Tensor()) + self.register_buffer("precisions_cholesky", torch.Tensor()) + + @torch.jit.export + def is_fit(self) -> bool: + return self.weights.numel() != 0 + + def forward(self, X: torch.Tensor) -> torch.Tensor: + """Compute the NLL of samples ``X`` under the GMM.""" + + # Check if model has been fitted + assert self.is_fit(), "model has not been fitted" + + estimated_log_probs = _estimate_log_gaussian_prob( + X, self.means, self.precisions_cholesky + ) + estimated_weights = torch.log(self.weights) + return -torch.logsumexp(estimated_log_probs + estimated_weights, dim=1) + + @torch.jit.unused + def fit( + self, + X: torch.Tensor, + max_components: int = 50, + rng: Optional[Union[torch.Generator, int]] = None, + ) -> None: + """Fit the GMM to the samples `X` using sklearn.""" + from sklearn import mixture + + # if RNG is an int, just use it as a seed; + # if RNG is None, use the current torch random state; + # if RNG is a torch.Generator, use that to generate an int seed for sklearn + # this way, this is by default seeded by torch without setting the numpy or sklearn seeds + random_state = ( + rng + if isinstance(rng, int) + else torch.randint(2**16, (1,), generator=rng).item() + ) + + gmm_kwargs = dict( + covariance_type=self.covariance_type, + random_state=random_state, + ) + + # If self.n_components is not provided (i.e, 0), set number of Gaussian + # components using BIC. The number of components should not exceed the + # number of samples in X and is capped at a heuristic of max_components + if not self.n_components: + components = list(range(1, min(max_components, X.size(dim=0)))) + gmms = [ + mixture.GaussianMixture(n_components=n, **gmm_kwargs) + for n in components + ] + bics = [model.fit(X).bic(X) for model in gmms] + self.n_components = components[np.argmin(bics)] + del gmms, bics, components + + # Fit GMM + gmm = mixture.GaussianMixture(n_components=self.n_components, **gmm_kwargs) + gmm.fit(X) + + # Save info from GMM into the register buffers + self.register_buffer("means", torch.from_numpy(gmm.means_)) + self.register_buffer("weights", torch.from_numpy(gmm.weights_)) + self.register_buffer("covariances", torch.from_numpy(gmm.covariances_)) + self.register_buffer( + "precisions_cholesky", torch.from_numpy(gmm.precisions_cholesky_) + ) diff --git a/nequip/utils/misc.py b/nequip/utils/misc.py index 4beba97b..34b04f7f 100644 --- a/nequip/utils/misc.py +++ b/nequip/utils/misc.py @@ -1,5 +1,34 @@ +from typing import Union +import contextlib + import torch -def dtype_from_name(name: str) -> torch.dtype: +def dtype_from_name(name: Union[str, torch.dtype]) -> torch.dtype: + if isinstance(name, torch.dtype): + return name return {"float32": torch.float32, "float64": torch.float64}[name] + + +def dtype_to_name(name: Union[str, torch.dtype]) -> torch.dtype: + if isinstance(name, str): + return name + return {torch.float32: "float32", torch.float64: "float64"}[name] + + +def get_default_device_name() -> str: + return "cuda" if torch.cuda.is_available() else "cpu" + + +@contextlib.contextmanager +def torch_default_dtype(dtype): + """Set `torch.get_default_dtype()` for the duration of a with block, cleaning up with a `finally`. + + Note that this is NOT thread safe, since `torch.set_default_dtype()` is not thread safe. + """ + orig_default_dtype = torch.get_default_dtype() + try: + torch.set_default_dtype(dtype) + yield + finally: + torch.set_default_dtype(orig_default_dtype) diff --git a/nequip/utils/regressor.py b/nequip/utils/regressor.py index 76d140bc..578c45f6 100644 --- a/nequip/utils/regressor.py +++ b/nequip/utils/regressor.py @@ -2,7 +2,6 @@ import torch from torch import matmul -from torch.linalg import solve, inv from typing import Optional, Sequence from opt_einsum import contract @@ -26,16 +25,20 @@ def solver(X, y, alpha: Optional[float] = 0.001, stride: Optional[int] = 1, **kw feature_rms = torch.sqrt(torch.mean(X**2, axis=0)) - alpha_mat = torch.diag(feature_rms) * alpha * alpha + alpha_mat = torch.diag(feature_rms) * (alpha * alpha) A = matmul(X.T, X) + alpha_mat dy = y - (torch.sum(X, axis=1, keepdim=True) * y_mean).reshape(y.shape) Xy = matmul(X.T, dy) - mean = solve(A, Xy) + # A is symmetric positive semidefinite <=> A=(X + alpha*I)^T (X + alpha*I), + # so we can use cholesky: + A_cholesky = torch.linalg.cholesky(A) + mean = torch.cholesky_solve(Xy.unsqueeze(-1), A_cholesky).squeeze(-1) + Ainv = torch.cholesky_inverse(A_cholesky) + del A_cholesky sigma2 = torch.var(matmul(X, mean) - dy) - Ainv = inv(A) cov = torch.sqrt(sigma2 * contract("ij,kj,kl,li->i", Ainv, X, X, Ainv)) mean = mean + y_mean.reshape([-1]) @@ -70,7 +73,10 @@ def down_sampling_by_composition( for i in range(n_types): ids = sort_by[id_start[i] : id_end[i]] for j, p in enumerate(percentage): - new_y[i * n_points + j] = torch.quantile(y[ids], p, interpolation="linear") + # it defaults to linear anyway, and `interpolation` was a 1.11 addition + # so we leave out `, interpolation="linear")` + # https://pytorch.org/docs/1.11/generated/torch.quantile.html?highlight=quantile#torch.quantile + new_y[i * n_points + j] = torch.quantile(y[ids], p) new_X[i * n_points + j] = unique_comps[i] return new_X, new_y diff --git a/nequip/utils/test.py b/nequip/utils/test.py index 60e68730..7c0bde3f 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -4,7 +4,7 @@ from e3nn import o3 from e3nn.util.test import equivariance_error, FLOAT_TOLERANCE -from nequip.nn import GraphModuleMixin +from nequip.nn import GraphModuleMixin, GraphModel from nequip.data import ( AtomicData, AtomicDataDict, @@ -13,7 +13,8 @@ ) -PERMUTATION_FLOAT_TOLERANCE = {torch.float32: 1e-5, torch.float64: 1e-10} +# This has to be somewhat large because of float32 sum reductions over many edges/atoms +PERMUTATION_FLOAT_TOLERANCE = {torch.float32: 1e-4, torch.float64: 1e-10} # https://discuss.pytorch.org/t/how-to-quickly-inverse-a-permutation-by-using-pytorch/116205/4 @@ -43,7 +44,11 @@ def assert_permutation_equivariant( __tracebackhide__ = True if tolerance is None: - atol = PERMUTATION_FLOAT_TOLERANCE[torch.get_default_dtype()] + atol = PERMUTATION_FLOAT_TOLERANCE[ + func.model_dtype + if isinstance(func, GraphModel) + else torch.get_default_dtype() + ] else: atol = tolerance @@ -142,7 +147,7 @@ def assert_AtomicData_equivariant( AtomicData, AtomicDataDict.Type, List[Union[AtomicData, AtomicDataDict.Type]] ], permutation_tolerance: Optional[float] = None, - o3_tolerance: Optional[float] = None, + e3_tolerance: Optional[float] = None, **kwargs, ) -> str: r"""Test the rotation, translation, parity, and permutation equivariance of ``func``. @@ -182,6 +187,12 @@ def assert_AtomicData_equivariant( irreps_in.update(func.irreps_in) irreps_in = {k: v for k, v in irreps_in.items() if k in data_in[0]} irreps_out = func.irreps_out.copy() + # Remove batch-related keys from the irreps_out, if we aren't using batched inputs + irreps_out = { + k: v + for k, v in irreps_out.items() + if not (k in ("batch", "ptr") and "batch" not in data_in) + } # for certain things, we don't care what the given irreps are... # make sure that we test correctly for equivariance: for irps in (irreps_in, irreps_out): @@ -193,9 +204,9 @@ def assert_AtomicData_equivariant( if AtomicDataDict.CELL_KEY in irps: prev_cell_irps = irps[AtomicDataDict.CELL_KEY] assert prev_cell_irps is None or o3.Irreps(prev_cell_irps) == o3.Irreps( - "3x1o" + "1o" ) - # must be this to actually rotate it + # must be this to actually rotate it when flattened irps[AtomicDataDict.CELL_KEY] = "3x1o" stress_keys = (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY) @@ -231,7 +242,9 @@ def wrapper(*args): # we need it to be decomposed into irreps for equivar testing for k in stress_keys: if k in output: - output[k] = stress_cart_tensor.from_cartesian(output[k], rtp=stress_rtp) + output[k] = stress_cart_tensor.from_cartesian( + output[k], rtp=stress_rtp.to(output[k].dtype) + ) return [output[k] for k in irreps_out] # prepare input data @@ -257,23 +270,27 @@ def wrapper(*args): # take max across errors errs = {k: torch.max(torch.vstack([e[k] for e in errs]), dim=0)[0] for k in errs[0]} - if o3_tolerance is None: - o3_tolerance = FLOAT_TOLERANCE[torch.get_default_dtype()] + current_dtype = ( + func.model_dtype if isinstance(func, GraphModel) else torch.get_default_dtype() + ) + if e3_tolerance is None: + e3_tolerance = FLOAT_TOLERANCE[current_dtype] all_errs = [] for case, err in errs.items(): for key, this_err in zip(irreps_out.keys(), err): all_errs.append(case + (key, this_err)) - is_problem = [e[-1] > o3_tolerance for e in all_errs] + is_problem = [e[-1] > e3_tolerance for e in all_errs] message = (permutation_message + "\n") + "\n".join( - " (parity_k={:1d}, did_translate={:5}, field={:20}) -> max error={:.3e}".format( - int(k[0]), str(bool(k[1])), str(k[2]), float(k[3]) - ) + f" (parity_k={int(k[0]):1d}, did_translate={str(bool(k[1])):5}, field={str(k[2]):20}) -> max error={float(k[3]):.3e}{' FAIL' if prob else ''}" for k, prob in zip(all_errs, is_problem) + if irreps_out[str(k[2])] is not None ) - if sum(is_problem) > 0 or "FAIL" in permutation_message: - raise AssertionError(f"Equivariance test failed for cases:\n{message}") + if any(is_problem) or " FAIL" in permutation_message: + raise AssertionError( + f"Equivariance test of {type(func).__name__} failed:\n default dtype: {torch.get_default_dtype()} (assumed) model dtype: {current_dtype} E(3) tolerance: {e3_tolerance}\n{message}" + ) return message @@ -323,15 +340,13 @@ def pre_hook(mod: GraphModuleMixin, inp): ) for k, ir in mod.irreps_in.items(): if k not in inp: - raise KeyError( - f"Field {k} with irreps {ir} expected to be input to {mname}; not present" - ) + pass elif isinstance(inp[k], torch.Tensor) and isinstance(ir, o3.Irreps): - if inp[k].ndim == 1: + if inp[k].ndim == 1 and inp[k].numel() > 0: raise ValueError( f"Field {k} in input to module {mname} has only one dimension (assumed to be batch-like); it must have a second irreps dimension even if irreps.dim == 1 (i.e. a single per atom scalar must have shape [N_at, 1], not [N_at])" ) - elif inp[k].shape[-1] != ir.dim: + elif inp[k].shape[-1] != ir.dim and inp[k].numel() > 0: raise ValueError( f"Field {k} in input to module {mname} has last dimension {inp[k].shape[-1]} but its irreps {ir} indicate last dimension {ir.dim}" ) @@ -350,15 +365,13 @@ def post_hook(mod: GraphModuleMixin, _, out): ) for k, ir in mod.irreps_out.items(): if k not in out: - raise KeyError( - f"Field {k} with irreps {ir} expected to be in output from {mname}; not present" - ) + pass elif isinstance(out[k], torch.Tensor) and isinstance(ir, o3.Irreps): - if out[k].ndim == 1: + if out[k].ndim == 1 and out[k].numel() > 0: raise ValueError( f"Field {k} in output from module {mname} has only one dimension (assumed to be batch-like); it must have a second irreps dimension even if irreps.dim == 1 (i.e. a single per atom scalar must have shape [N_at, 1], not [N_at])" ) - elif out[k].shape[-1] != ir.dim: + elif out[k].shape[-1] != ir.dim and out[k].numel() > 0: raise ValueError( f"Field {k} in output from {mname} has last dimension {out[k].shape[-1]} but its irreps {ir} indicate last dimension {ir.dim}" ) diff --git a/nequip/utils/unittests/conftest.py b/nequip/utils/unittests/conftest.py index 4cfa98ff..a2dc103d 100644 --- a/nequip/utils/unittests/conftest.py +++ b/nequip/utils/unittests/conftest.py @@ -6,7 +6,7 @@ import os from ase.atoms import Atoms -from ase.build import molecule +from ase.build import molecule, bulk from ase.calculators.singlepoint import SinglePointCalculator from ase.io import write @@ -19,6 +19,25 @@ from nequip.utils._global_options import _set_global_options from nequip.utils.misc import dtype_from_name +# Sometimes we run parallel using pytest-xdist, and want to be able to use +# as many GPUs as are available +# https://pytest-xdist.readthedocs.io/en/latest/how-to.html#identifying-the-worker-process-during-a-test +_is_pytest_xdist: bool = os.environ.get("PYTEST_XDIST_WORKER", "master") != "master" +if _is_pytest_xdist and torch.cuda.is_available(): + _xdist_worker_rank: int = int(os.environ["PYTEST_XDIST_WORKER"].lstrip("gw")) + _cuda_vis_devs = os.environ.get( + "CUDA_VISIBLE_DEVICES", + ",".join(str(e) for e in range(torch.cuda.device_count())), + ).split(",") + _cuda_vis_devs = [int(e) for e in _cuda_vis_devs] + # set this for tests that run in this process + _local_gpu_rank = _xdist_worker_rank % torch.cuda.device_count() + torch.cuda.set_device(_local_gpu_rank) + # set this for launched child processes + os.environ["CUDA_VISIBLE_DEVICES"] = str(_cuda_vis_devs[_local_gpu_rank]) + del _xdist_worker_rank, _cuda_vis_devs, _local_gpu_rank + + if "NEQUIP_NUM_TASKS" not in os.environ: # Test parallelization, but don't waste time spawning tons of workers if lots of cores available os.environ["NEQUIP_NUM_TASKS"] = "2" @@ -96,6 +115,16 @@ def CH3CHO_no_typemap(float_tolerance) -> Tuple[Atoms, AtomicData]: return atoms, data +@pytest.fixture(scope="session") +def Cu_bulk(float_tolerance) -> Tuple[Atoms, AtomicData]: + atoms = bulk("Cu") * (2, 2, 1) + atoms.rattle() + data = AtomicData.from_ase(atoms, r_max=3.5) + tm = TypeMapper(chemical_symbol_to_type={"Cu": 0}) + data = tm(data) + return atoms, data + + @pytest.fixture(scope="session") def molecules() -> List[Atoms]: atoms_list = [] @@ -121,7 +150,7 @@ def nequip_dataset(molecules, temp_data, float_tolerance): a = ASEDataset( file_name=fp.name, root=temp_data, - extra_fixed_fields={"r_max": 3.0}, + AtomicData_options={"r_max": 3.0}, ase_args=dict(format="extxyz"), type_mapper=TypeMapper(chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}), ) diff --git a/nequip/utils/unittests/model_tests.py b/nequip/utils/unittests/model_tests.py index 2b6a8b63..37e9dcb6 100644 --- a/nequip/utils/unittests/model_tests.py +++ b/nequip/utils/unittests/model_tests.py @@ -19,6 +19,7 @@ from nequip.data.transforms import TypeMapper from nequip.model import model_from_config from nequip.nn import GraphModuleMixin +from nequip.utils import Config from nequip.utils.test import assert_AtomicData_equivariant @@ -55,12 +56,14 @@ def make_model(config, device, initialize: bool = True, deploy: bool = False): "types_names": ["H", "C", "O"], } ) - model = model_from_config(config, initialize=initialize, deploy=deploy) + model = model_from_config( + Config.from_dict(config), initialize=initialize, deploy=deploy + ) model = model.to(device) return model @pytest.fixture(scope="class") - def model(self, config, device): + def model(self, config, device, float_tolerance): config, out_fields = config model = self.make_model(config, device=device) return model, out_fields @@ -76,12 +79,22 @@ def test_jit(self, model, atomic_batch, device): instance = instance.to(device=device) model_script = script(instance) + atol = { + # tight, but not that tight, since GPU nondet has to pass + # plus model insides are still float32 with global dtype float64 in the tests + torch.float32: 5e-5, + torch.float64: 5e-7, + }[torch.get_default_dtype()] + + out_instance = instance(data.copy()) + out_script = model_script(data.copy()) + for out_field in out_fields: assert torch.allclose( - instance(data)[out_field], - model_script(data)[out_field], - atol=1e-6, - ) + out_instance[out_field], + out_script[out_field], + atol=atol, + ), f"JIT didn't repro non-JIT on field {out_field} with max error {(out_instance[out_field] - out_script[out_field]).abs().max().item()}" # - Try saving, loading in another process, and running - with tempfile.TemporaryDirectory() as tmpdir: @@ -94,18 +107,15 @@ def test_jit(self, model, atomic_batch, device): load_model = torch.jit.load(tmpdir + "/model.pt") load_dat = torch.load(tmpdir + "/dat.pt") - atol = { - # tight, but not that tight, since GPU nondet has to pass - torch.float32: 1e-6, - torch.float64: 1e-10, - }[torch.get_default_dtype()] + out_script = model_script(data.copy()) + out_load = load_model(load_dat.copy()) for out_field in out_fields: assert torch.allclose( - model_script(data)[out_field], - load_model(load_dat)[out_field], + out_script[out_field], + out_load[out_field], atol=atol, - ) + ), f"JIT didn't repro save-and-loaded JIT on field {out_field} with max error {(out_script[out_field] - out_load[out_field]).abs().max().item()}" def test_forward(self, model, atomic_batch, device): instance, out_fields = model @@ -115,6 +125,53 @@ def test_forward(self, model, atomic_batch, device): for out_field in out_fields: assert out_field in output + def test_wrapped_unwrapped(self, model, device, Cu_bulk, float_tolerance): + atoms, data_orig = Cu_bulk + instance, out_fields = model + data = AtomicData.from_ase(atoms, r_max=3.5) + data[AtomicDataDict.ATOM_TYPE_KEY] = data_orig[AtomicDataDict.ATOM_TYPE_KEY] + data.to(device) + out_ref = instance(AtomicData.to_AtomicDataDict(data)) + # now put things in other periodic images + rng = torch.Generator(device=device).manual_seed(12345) + # try a few different shifts + for _ in range(3): + cell_shifts = torch.randint( + -5, + 5, + (len(atoms), 3), + device=device, + dtype=data[AtomicDataDict.POSITIONS_KEY].dtype, + generator=rng, + ) + shifts = torch.einsum( + "zi,ix->zx", cell_shifts, data[AtomicDataDict.CELL_KEY] + ) + atoms2 = atoms.copy() + atoms2.positions += shifts.detach().cpu().numpy() + # must recompute the neighborlist for this, since the edge_cell_shifts changed + data2 = AtomicData.from_ase(atoms2, r_max=3.5) + data2[AtomicDataDict.ATOM_TYPE_KEY] = data[AtomicDataDict.ATOM_TYPE_KEY] + data2.to(device) + assert torch.equal( + data[AtomicDataDict.EDGE_INDEX_KEY], + data2[AtomicDataDict.EDGE_INDEX_KEY], + ) + tmp = ( + data[AtomicDataDict.EDGE_CELL_SHIFT_KEY] + + cell_shifts[data[AtomicDataDict.EDGE_INDEX_KEY][0]] + - cell_shifts[data[AtomicDataDict.EDGE_INDEX_KEY][1]] + ) + assert torch.equal( + tmp, + data2[AtomicDataDict.EDGE_CELL_SHIFT_KEY], + ) + out_unwrapped = instance(AtomicData.to_AtomicDataDict(data2)) + for out_field in out_fields: + assert torch.allclose( + out_ref[out_field], out_unwrapped[out_field], atol=float_tolerance + ) + def test_batch(self, model, atomic_batch, device, float_tolerance): """Confirm that the results for individual examples are the same regardless of whether they are batched.""" allclose = functools.partial(torch.allclose, atol=float_tolerance) @@ -175,6 +232,18 @@ def test_equivariance(self, model, atomic_batch, device): def test_embedding_cutoff(self, model, config, device): instance, out_fields = model + + # make all weights nonzero in order to have the most robust test + # default init weights can sometimes be zero (e.g. biases) but we want + # to ensure smoothness for nonzero values + # assumes any trainable parameter will be trained and thus that + # nonzero values are valid + with torch.no_grad(): + all_params = list(instance.parameters()) + old_state = [p.detach().clone() for p in all_params] + for p in all_params: + p.uniform_(-1.0, 1.0) + config, out_fields = config r_max = config["r_max"] @@ -188,8 +257,10 @@ def test_embedding_cutoff(self, model, config, device): edge_embed = instance(AtomicData.to_AtomicDataDict(data)) if AtomicDataDict.EDGE_FEATURES_KEY in edge_embed: key = AtomicDataDict.EDGE_FEATURES_KEY - else: + elif AtomicDataDict.EDGE_EMBEDDING_KEY in edge_embed: key = AtomicDataDict.EDGE_EMBEDDING_KEY + else: + pytest.skip() edge_embed = edge_embed[key] data.pos[2, 1] = r_max # put it past the cutoff edge_embed2 = instance(AtomicData.to_AtomicDataDict(data))[key] @@ -199,7 +270,9 @@ def test_embedding_cutoff(self, model, config, device): # For example, an Allegro edge feature is many body so will be affected assert torch.allclose(edge_embed[:2], edge_embed2[:2]) assert edge_embed[2:].abs().sum() > 1e-6 # some nonzero terms - assert torch.allclose(edge_embed2[2:], torch.zeros(1, device=device)) + assert torch.allclose( + edge_embed2[2:], torch.zeros(1, device=device, dtype=edge_embed2.dtype) + ) # test gradients in_dict = AtomicData.to_AtomicDataDict(data) @@ -214,7 +287,9 @@ def test_embedding_cutoff(self, model, config, device): inputs=in_dict[AtomicDataDict.POSITIONS_KEY], retain_graph=True, )[0] - assert torch.allclose(grads, torch.zeros(1, device=device)) + assert torch.allclose( + grads, torch.zeros(1, device=device, dtype=grads.dtype) + ) if AtomicDataDict.PER_ATOM_ENERGY_KEY in out: # are the first two atom's energies unaffected by atom at the cutoff? @@ -227,6 +302,11 @@ def test_embedding_cutoff(self, model, config, device): assert grads.shape == (3, 3) assert torch.allclose(grads[2], torch.zeros(1, device=device)) + # restore previous model state + with torch.no_grad(): + for p, v in zip(all_params, old_state): + p.copy_(v) + class BaseEnergyModelTests(BaseModelTests): def test_large_separation(self, model, config, molecules, device): @@ -261,6 +341,16 @@ def test_large_separation(self, model, config, molecules, device): out_both[AtomicDataDict.TOTAL_ENERGY_KEY], atol=atol, ) + if AtomicDataDict.FORCE_KEY in out1: + # check forces if it's a force model + assert torch.allclose( + torch.cat( + (out1[AtomicDataDict.FORCE_KEY], out2[AtomicDataDict.FORCE_KEY]), + dim=0, + ), + out_both[AtomicDataDict.FORCE_KEY], + atol=atol, + ) atoms_both2 = atoms1.copy() atoms3 = atoms2.copy() @@ -359,7 +449,10 @@ def test_partial_forces(self, config, atomic_batch, device, strict_locality): assert torch.allclose( output[k], output_partial[k], - atol=1e-8 if k == AtomicDataDict.TOTAL_ENERGY_KEY else 1e-6, + atol=1e-8 + if k == AtomicDataDict.TOTAL_ENERGY_KEY + and torch.get_default_dtype() == torch.float64 + else 1e-5, ) else: assert torch.equal(output[k], output_partial[k]) @@ -381,4 +474,40 @@ def test_partial_forces(self, config, atomic_batch, device, strict_locality): adjacency = data[AtomicDataDict.BATCH_KEY].view(-1, 1) == data[ AtomicDataDict.BATCH_KEY ].view(1, -1) - assert torch.equal(adjacency, torch.any(partial_forces != 0, dim=-1)) + # for non-adjacent atoms, all partial forces must be zero + assert torch.all(partial_forces[~adjacency] == 0) + + def test_force_smoothness(self, model, config, device): + instance, out_fields = model + if AtomicDataDict.FORCE_KEY not in out_fields: + pytest.skip() + # see test_embedding_cutoff + with torch.no_grad(): + all_params = list(instance.parameters()) + old_state = [p.detach().clone() for p in all_params] + for p in all_params: + p.uniform_(-3.0, 3.0) + config, out_fields = config + r_max = config["r_max"] + + # make a synthetic three atom example + data = AtomicData( + atom_types=np.random.choice([0, 1, 2], size=3), + pos=np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [r_max, 0.0, 0.0]]), + edge_index=np.array([[0, 1, 0, 2], [1, 0, 2, 0]]), + ) + data = data.to(device) + out = instance(AtomicData.to_AtomicDataDict(data)) + forces = out[AtomicDataDict.FORCE_KEY] + assert ( + forces[:2].abs().sum() > 1e-4 + ) # some nonzero terms on the two connected atoms + assert torch.allclose( + forces[2], + torch.zeros(1, device=device, dtype=forces.dtype), + ) # the atom at the cutoff should be zero + + # restore previous model state + with torch.no_grad(): + for p, v in zip(all_params, old_state): + p.copy_(v) diff --git a/nequip/utils/versions.py b/nequip/utils/versions.py index db35a451..3c733c65 100644 --- a/nequip/utils/versions.py +++ b/nequip/utils/versions.py @@ -1,4 +1,5 @@ -from typing import Tuple +from typing import Tuple, Final +import packaging.version import logging @@ -8,6 +9,10 @@ from .git import get_commit +_TORCH_IS_GE_1_13: Final[bool] = packaging.version.parse( + torch.__version__ +) >= packaging.version.parse("1.13.0") + _DEFAULT_VERSION_CODES = [torch, e3nn, nequip] _DEFAULT_COMMIT_CODES = ["e3nn", "nequip"] diff --git a/nequip/utils/wandb.py b/nequip/utils/wandb.py index 2391a9f4..7f0d5e10 100644 --- a/nequip/utils/wandb.py +++ b/nequip/utils/wandb.py @@ -1,19 +1,25 @@ -import wandb import logging +import secrets + +from nequip.utils import Config + +import wandb from wandb.util import json_friendly_val def init_n_update(config): - conf_dict = dict(config) + conf_dict = Config.as_dict(config) # wandb mangles keys (in terms of type) as well, but we can't easily correct that because there are many ambiguous edge cases. (E.g. string "-1" vs int -1 as keys, are they different config keys?) if any(not isinstance(k, str) for k in conf_dict.keys()): raise TypeError( "Due to wandb limitations, only string keys are supported in configurations." ) - # download from wandb set up - config.run_id = wandb.util.generate_id() + # create a run id + # see https://github.com/wandb/wandb/pull/4676 + config.run_id = secrets.token_urlsafe() + # download from wandb set up wandb.init( project=config.wandb_project, config=conf_dict, @@ -27,7 +33,9 @@ def init_n_update(config): skip = False if k in config.keys(): # double check the one sanitized by wandb - v_old = json_friendly_val(config[k]) + # because we're preprocessing the config and looping over + # _every_ key, don't mark accessed keys as valid => _get_nomark + v_old = json_friendly_val(config._get_nomark(k)) if repr(v_new) == repr(v_old): skip = True if skip: diff --git a/setup.py b/setup.py index d7a5b465..6ca9e3cf 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,6 @@ "numpy", "ase", "tqdm", - "torch>=1.10.0,<1.13,!=1.9.0", "e3nn>=0.4.4,<0.6.0", "pyyaml", "contextlib2;python_version<'3.7'", # backport of nullcontext diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 00000000..ceb840b7 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,182 @@ +import pytest +import tempfile +import pathlib +import yaml +import subprocess +import os +import sys + +import torch + +from nequip.data import AtomicDataDict +from nequip.nn import GraphModuleMixin + + +def _check_and_print(retcode): + __tracebackhide__ = True + if retcode.returncode: + if retcode.stdout is not None and len(retcode.stdout) > 0: + print(retcode.stdout.decode("ascii")) + if retcode.stderr is not None and len(retcode.stderr) > 0: + print(retcode.stderr.decode("ascii"), file=sys.stderr) + retcode.check_returncode() + + +class IdentityModel(GraphModuleMixin, torch.nn.Module): + def __init__(self, **kwargs): + super().__init__() + self._init_irreps( + irreps_in={ + AtomicDataDict.TOTAL_ENERGY_KEY: "0e", + AtomicDataDict.FORCE_KEY: "1o", + }, + ) + self.zero = torch.nn.Parameter(torch.as_tensor(0.0)) + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + err = self.zero + data[AtomicDataDict.FORCE_KEY] = data[AtomicDataDict.FORCE_KEY] + err + data[AtomicDataDict.NODE_FEATURES_KEY] = ( + 0.77 * data[AtomicDataDict.FORCE_KEY].tanh() + ) # some BS + data[AtomicDataDict.TOTAL_ENERGY_KEY] = ( + data[AtomicDataDict.TOTAL_ENERGY_KEY] + err + ) + return data + + +class ConstFactorModel(GraphModuleMixin, torch.nn.Module): + def __init__(self, **kwargs): + super().__init__() + self._init_irreps( + irreps_in={ + AtomicDataDict.TOTAL_ENERGY_KEY: "0e", + AtomicDataDict.FORCE_KEY: "1o", + }, + ) + # to keep the optimizer happy: + self.dummy = torch.nn.Parameter(torch.zeros(1)) + self.register_buffer("factor", 3.7777 * torch.randn(1).squeeze()) + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + data[AtomicDataDict.FORCE_KEY] = ( + self.factor * data[AtomicDataDict.FORCE_KEY] + 0.0 * self.dummy + ) + data[AtomicDataDict.NODE_FEATURES_KEY] = ( + 0.77 * data[AtomicDataDict.FORCE_KEY].tanh() + ) # some BS + data[AtomicDataDict.TOTAL_ENERGY_KEY] = ( + self.factor * data[AtomicDataDict.TOTAL_ENERGY_KEY] + 0.0 * self.dummy + ) + return data + + +class LearningFactorModel(GraphModuleMixin, torch.nn.Module): + def __init__(self, **kwargs): + super().__init__() + self._init_irreps( + irreps_in={ + AtomicDataDict.TOTAL_ENERGY_KEY: "0e", + AtomicDataDict.FORCE_KEY: "1o", + }, + ) + # By using a big factor, we keep it in a nice descending part + # of the optimization without too much oscilation in loss at + # the beginning + self.factor = torch.nn.Parameter(torch.as_tensor(1.111)) + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + data[AtomicDataDict.FORCE_KEY] = self.factor * data[AtomicDataDict.FORCE_KEY] + data[AtomicDataDict.NODE_FEATURES_KEY] = ( + 0.77 * data[AtomicDataDict.FORCE_KEY].tanh() + ) # some BS + data[AtomicDataDict.TOTAL_ENERGY_KEY] = ( + self.factor * data[AtomicDataDict.TOTAL_ENERGY_KEY] + ) + return data + + +def _training_session(conffile, model_dtype, builder, BENCHMARK_ROOT): + default_dtype = str(torch.get_default_dtype())[len("torch.") :] + if default_dtype == "float32" and model_dtype == "float64": + pytest.skip("default_dtype=float32 and model_dtype=float64 doesn't make sense") + + path_to_this_file = pathlib.Path(__file__) + config_path = path_to_this_file.parents[2] / f"configs/{conffile}" + true_config = yaml.load(config_path.read_text(), Loader=yaml.Loader) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save time + run_name = "test_train_" + default_dtype + true_config["run_name"] = run_name + true_config["root"] = "./" + true_config["dataset_file_name"] = str( + BENCHMARK_ROOT / "aspirin_ccsd-train.npz" + ) + true_config["default_dtype"] = default_dtype + true_config["model_dtype"] = model_dtype + true_config["max_epochs"] = 2 + true_config["model_builders"] = [builder] + # just do forces, which is what the mock models have: + true_config["loss_coeffs"] = "forces" + # We need truth labels as inputs for these fake testing models + true_config["model_input_fields"] = { + AtomicDataDict.FORCE_KEY: "1o", + AtomicDataDict.TOTAL_ENERGY_KEY: "0e", + } + + config_path = tmpdir + "/conf.yaml" + with open(config_path, "w+") as fp: + yaml.dump(true_config, fp) + # == Train model == + env = dict(os.environ) + # make this script available so model builders can be loaded + env["PYTHONPATH"] = ":".join( + [str(path_to_this_file.parent)] + env.get("PYTHONPATH", "").split(":") + ) + + retcode = subprocess.run( + # we use --warn-unused because we are using configs with many unused keys for testing + ["nequip-train", "conf.yaml", "--warn-unused"], + cwd=tmpdir, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + _check_and_print(retcode) + + yield true_config, tmpdir, env + + +@pytest.fixture( + scope="session", + params=[ + ("minimal.yaml", AtomicDataDict.FORCE_KEY), + ("minimal_toy_emt.yaml", AtomicDataDict.STRESS_KEY), + ], +) +def conffile(request): + return request.param + + +@pytest.fixture( + scope="session", + params=["float32", "float64"], +) +def model_dtype(request, float_tolerance): + if torch.get_default_dtype() == torch.float32 and model_dtype == "float64": + pytest.skip("default_dtype=float32 and model_dtype=float64 doesn't make sense") + return request.param + + +@pytest.fixture( + scope="session", params=[ConstFactorModel, LearningFactorModel, IdentityModel] +) +def fake_model_training_session(request, BENCHMARK_ROOT, conffile, model_dtype): + conffile, _ = conffile + builder = request.param + + session = _training_session(conffile, model_dtype, builder, BENCHMARK_ROOT) + true_config, tmpdir, env = next(session) + yield builder, true_config, tmpdir, env + del session diff --git a/tests/integration/test_deploy.py b/tests/integration/test_deploy.py index cc710f11..51d6b936 100644 --- a/tests/integration/test_deploy.py +++ b/tests/integration/test_deploy.py @@ -15,17 +15,19 @@ from nequip.train import Trainer from nequip.ase import NequIPCalculator +from conftest import _check_and_print + @pytest.mark.parametrize( "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) ) -def test_deploy(BENCHMARK_ROOT, device): +@pytest.mark.parametrize("model_dtype", ["float32", "float64"]) +def test_deploy(BENCHMARK_ROOT, device, model_dtype): dtype = str(torch.get_default_dtype())[len("torch.") :] - atol = {"float32": 1e-5, "float64": 1e-7}[dtype] - - # if torch.cuda.is_available(): - # # TODO: is this true? - # pytest.skip("CUDA and subprocesses have issues") + if dtype == "float32" and model_dtype == "float64": + pytest.skip("default_dtype=float32 and model_dtype=float64 doesn't make sense") + # atol on MODEL dtype, since a mostly float32 model still has float32 variation + atol = {"float32": 1e-5, "float64": 1e-7}[model_dtype] keys = [ AtomicDataDict.TOTAL_ENERGY_KEY, @@ -45,8 +47,9 @@ def test_deploy(BENCHMARK_ROOT, device): BENCHMARK_ROOT / "aspirin_ccsd-train.npz" ) true_config["default_dtype"] = dtype + true_config["model_dtype"] = model_dtype true_config["max_epochs"] = 1 - true_config["n_train"] = 1 + true_config["n_train"] = 2 true_config["n_val"] = 1 config_path = "conf.yaml" full_config_path = f"{tmpdir}/{config_path}" @@ -54,7 +57,7 @@ def test_deploy(BENCHMARK_ROOT, device): yaml.dump(true_config, fp) # Train model retcode = subprocess.run(["nequip-train", str(config_path)], cwd=tmpdir) - retcode.check_returncode() + _check_and_print(retcode) # Deploy deployed_path = pathlib.Path(f"deployed_{dtype}.pth") retcode = subprocess.run( @@ -67,12 +70,12 @@ def test_deploy(BENCHMARK_ROOT, device): ], cwd=tmpdir, ) - retcode.check_returncode() + _check_and_print(retcode) deployed_path = tmpdir / deployed_path assert deployed_path.is_file(), "Deploy didn't create file" # now test predictions the same - best_mod, _ = Trainer.load_model_from_training_session( + best_mod, train_config = Trainer.load_model_from_training_session( traindir=f"{root}/{run_name}/", model_name="best_model.pth", device=device, @@ -119,7 +122,7 @@ def test_deploy(BENCHMARK_ROOT, device): stdout=subprocess.PIPE, **text, ) - retcode.check_returncode() + _check_and_print(retcode) # Try to load extract config config = yaml.load(retcode.stdout, Loader=yaml.Loader) del config diff --git a/tests/integration/test_evaluate.py b/tests/integration/test_evaluate.py index 2bec1215..4dd9bce0 100644 --- a/tests/integration/test_evaluate.py +++ b/tests/integration/test_evaluate.py @@ -1,9 +1,5 @@ import pytest -import tempfile -import pathlib -import yaml import subprocess -import os import textwrap import shutil @@ -14,72 +10,22 @@ from nequip.data import AtomicDataDict -from test_train import ConstFactorModel, IdentityModel # noqa - - -@pytest.fixture( - scope="module", - params=[ - ("minimal.yaml", AtomicDataDict.FORCE_KEY), - ], -) -def conffile(request): - return request.param - - -@pytest.fixture(scope="module", params=[ConstFactorModel, IdentityModel]) -def training_session(request, BENCHMARK_ROOT, conffile): - conffile, _ = conffile - builder = request.param - dtype = str(torch.get_default_dtype())[len("torch.") :] - - # if torch.cuda.is_available(): - # # TODO: is this true? - # pytest.skip("CUDA and subprocesses have issues") - - path_to_this_file = pathlib.Path(__file__) - config_path = path_to_this_file.parents[2] / f"configs/{conffile}" - true_config = yaml.load(config_path.read_text(), Loader=yaml.Loader) - with tempfile.TemporaryDirectory() as tmpdir: - # == Run training == - # Save time - run_name = "test_train_" + dtype - true_config["run_name"] = run_name - true_config["root"] = "./" - true_config["dataset_file_name"] = str( - BENCHMARK_ROOT / "aspirin_ccsd-train.npz" - ) - true_config["default_dtype"] = dtype - true_config["max_epochs"] = 2 - true_config["model_builders"] = [builder] - # We need truth labels as inputs for these fake testing models - true_config["_override_allow_truth_label_inputs"] = True - - # to be a true identity, we can't have rescaling - true_config["global_rescale_shift"] = None - true_config["global_rescale_scale"] = None - - config_path = tmpdir + "/conf.yaml" - with open(config_path, "w+") as fp: - yaml.dump(true_config, fp) - # == Train model == - env = dict(os.environ) - # make this script available so model builders can be loaded - env["PYTHONPATH"] = ":".join( - [str(path_to_this_file.parent)] + env.get("PYTHONPATH", "").split(":") - ) - retcode = subprocess.run(["nequip-train", "conf.yaml"], cwd=tmpdir, env=env) - retcode.check_returncode() - - yield builder, true_config, tmpdir, env +from conftest import IdentityModel, ConstFactorModel, _check_and_print @pytest.mark.parametrize("do_test_idcs", [True, False]) @pytest.mark.parametrize("do_metrics", [True, False]) @pytest.mark.parametrize("do_output_fields", [True, False]) -def test_metrics(training_session, do_test_idcs, do_metrics, do_output_fields): - - builder, true_config, tmpdir, env = training_session +def test_metrics( + fake_model_training_session, conffile, do_test_idcs, do_metrics, do_output_fields +): + energy_only: bool = conffile[0] == "minimal_eng.yaml" + if energy_only: + # By default, don't run the energy only tests... they are redundant and add a _lot_ of expense + pytest.skip() + builder, true_config, tmpdir, env = fake_model_training_session + if builder not in (IdentityModel, ConstFactorModel): + pytest.skip() # == Run test error == outdir = f"{true_config['root']}/{true_config['run_name']}/" @@ -105,7 +51,7 @@ def runit(params: dict): stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) - retcode.check_returncode() + _check_and_print(retcode) # Check the output metrics = dict( @@ -119,9 +65,16 @@ def runit(params: dict): # Test idcs if do_test_idcs: - # The Aspirin dataset is 1000 frames long - # Pick some arbitrary number of frames - test_idcs_arr = torch.randperm(1000)[:257] + if conffile[0] == "minimal.yaml": + # The Aspirin dataset is 1000 frames long + # Pick some arbitrary number of frames + test_idcs_arr = torch.randperm(1000)[:257] + elif conffile[0] == "minimal_toy_emt.yaml": + # The toy EMT dataset is 50 frames long + # Pick some arbitrary number of frames + test_idcs_arr = torch.randperm(50)[:7] + else: + raise KeyError test_idcs = "some-test-idcs.pth" torch.save(test_idcs_arr, f"{tmpdir}/{test_idcs}") else: @@ -134,39 +87,64 @@ def runit(params: dict): metrics_yaml = "my-metrics.yaml" with open(f"{tmpdir}/{metrics_yaml}", "w") as f: # Write out a fancier metrics file - f.write( - textwrap.dedent( - """ - metrics_components: - - - forces - - rmse - - report_per_component: True - - - forces - - mae - - PerSpecies: True - - - total_energy - - mae - - - total_energy - - mae - - PerAtom: True - """ + if energy_only: + f.write( + textwrap.dedent( + """ + metrics_components: + - - total_energy + - mae + - - total_energy + - mae + - PerAtom: True + """ + ) + ) + expect_metrics = { + "e_mae", + "e/N_mae", + } + else: + # Write out a fancier metrics file + f.write( + textwrap.dedent( + """ + metrics_components: + - - forces + - rmse + - report_per_component: True + - - forces + - mae + - PerSpecies: True + - - total_energy + - mae + - - total_energy + - mae + - PerAtom: True + """ + ) + ) + expect_metrics = { + "f_rmse_0", + "f_rmse_1", + "f_rmse_2", + "psavg_f_mae", + "e_mae", + "e/N_mae", + }.union( + { + # For the PerSpecies + sym + "_f_mae" + for sym in true_config["chemical_symbols"] + } ) - ) - expect_metrics = { - "f_rmse_0", - "f_rmse_1", - "f_rmse_2", - "H_f_mae", - "C_f_mae", - "O_f_mae", - "psavg_f_mae", - "e_mae", - "e/N_mae", - } else: metrics_yaml = None # Regardless of builder, with minimal.yaml, we should have RMSE and MAE - expect_metrics = {"f_mae", "f_rmse"} + if energy_only: + expect_metrics = {"e_mae", "e_rmse"} + else: + expect_metrics = {"f_mae", "f_rmse"} default_params["metrics-config"] = metrics_yaml if do_output_fields: @@ -187,8 +165,16 @@ def runit(params: dict): # check metrics if builder == IdentityModel: + true_identity: bool = true_config["default_dtype"] == true_config["model_dtype"] for metric, err in metrics.items(): - assert np.allclose(err, 0.0), f"Metric `{metric}` wasn't zero!" + # see test_train.py for discussion + assert np.allclose( + err, + 0.0, + atol=1e-8 + if true_identity + else (1e-2 if metric.startswith("e") else 1e-4), + ), f"Metric `{metric}` wasn't zero!" elif builder == ConstFactorModel: # TODO: check comperable to naive numpy compute pass diff --git a/tests/integration/test_train.py b/tests/integration/test_train.py index 36597a98..b9935b3c 100644 --- a/tests/integration/test_train.py +++ b/tests/integration/test_train.py @@ -9,200 +9,107 @@ import torch from nequip.data import AtomicDataDict -from nequip.nn import GraphModuleMixin - -class IdentityModel(GraphModuleMixin, torch.nn.Module): - def __init__(self, **kwargs): - super().__init__() - self._init_irreps( - irreps_in={ - AtomicDataDict.TOTAL_ENERGY_KEY: "0e", - AtomicDataDict.FORCE_KEY: "1o", - }, - ) - self.one = torch.nn.Parameter(torch.as_tensor(1.0)) - - def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: - data[AtomicDataDict.FORCE_KEY] = self.one * data[AtomicDataDict.FORCE_KEY] - data[AtomicDataDict.NODE_FEATURES_KEY] = ( - 0.77 * data[AtomicDataDict.FORCE_KEY].tanh() - ) # some BS - data[AtomicDataDict.TOTAL_ENERGY_KEY] = ( - self.one * data[AtomicDataDict.TOTAL_ENERGY_KEY] - ) - return data - - -class ConstFactorModel(GraphModuleMixin, torch.nn.Module): - def __init__(self, **kwargs): - super().__init__() - self._init_irreps( - irreps_in={ - AtomicDataDict.TOTAL_ENERGY_KEY: "0e", - AtomicDataDict.FORCE_KEY: "1o", - }, - ) - # to keep the optimizer happy: - self.dummy = torch.nn.Parameter(torch.zeros(1)) - self.register_buffer("factor", 3.7777 * torch.randn(1).squeeze()) - - def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: - data[AtomicDataDict.FORCE_KEY] = ( - self.factor * data[AtomicDataDict.FORCE_KEY] + 0.0 * self.dummy - ) - data[AtomicDataDict.NODE_FEATURES_KEY] = ( - 0.77 * data[AtomicDataDict.FORCE_KEY].tanh() - ) # some BS - data[AtomicDataDict.TOTAL_ENERGY_KEY] = ( - self.factor * data[AtomicDataDict.TOTAL_ENERGY_KEY] + 0.0 * self.dummy - ) - return data - - -class LearningFactorModel(GraphModuleMixin, torch.nn.Module): - def __init__(self, **kwargs): - super().__init__() - self._init_irreps( - irreps_in={ - AtomicDataDict.TOTAL_ENERGY_KEY: "0e", - AtomicDataDict.FORCE_KEY: "1o", - }, - ) - # By using a big factor, we keep it in a nice descending part - # of the optimization without too much oscilation in loss at - # the beginning - self.factor = torch.nn.Parameter(torch.as_tensor(1.111)) - - def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: - data[AtomicDataDict.FORCE_KEY] = self.factor * data[AtomicDataDict.FORCE_KEY] - data[AtomicDataDict.NODE_FEATURES_KEY] = ( - 0.77 * data[AtomicDataDict.FORCE_KEY].tanh() - ) # some BS - data[AtomicDataDict.TOTAL_ENERGY_KEY] = ( - self.factor * data[AtomicDataDict.TOTAL_ENERGY_KEY] - ) - return data - - -@pytest.mark.parametrize( - "conffile", - [ - "minimal.yaml", - "minimal_eng.yaml", - ], +from conftest import ( + IdentityModel, + ConstFactorModel, + LearningFactorModel, + _check_and_print, ) -@pytest.mark.parametrize( - "builder", [IdentityModel, ConstFactorModel, LearningFactorModel] -) -def test_metrics(nequip_dataset, BENCHMARK_ROOT, conffile, builder): - - dtype = str(torch.get_default_dtype())[len("torch.") :] - - # if torch.cuda.is_available(): - # # TODO: is this true? - # pytest.skip("CUDA and subprocesses have issues") - - path_to_this_file = pathlib.Path(__file__) - config_path = path_to_this_file.parents[2] / f"configs/{conffile}" - true_config = yaml.load(config_path.read_text(), Loader=yaml.Loader) - - with tempfile.TemporaryDirectory() as tmpdir: - # Save time - run_name = "test_train_" + dtype - true_config["run_name"] = run_name - true_config["root"] = "./" - true_config["dataset_file_name"] = str( - BENCHMARK_ROOT / "aspirin_ccsd-train.npz" - ) - true_config["default_dtype"] = dtype - true_config["max_epochs"] = 2 - # We just don't add rescaling: - true_config["model_builders"] = [builder] - # We need truth labels as inputs for these fake testing models - true_config["_override_allow_truth_label_inputs"] = True - config_path = tmpdir + "/conf.yaml" - with open(config_path, "w+") as fp: - yaml.dump(true_config, fp) - # == Train model == - env = dict(os.environ) - # make this script available so model builders can be loaded - env["PYTHONPATH"] = ":".join( - [str(path_to_this_file.parent)] + env.get("PYTHONPATH", "").split(":") - ) - retcode = subprocess.run( - ["nequip-train", "conf.yaml"], - cwd=tmpdir, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - retcode.check_returncode() +def test_metrics(fake_model_training_session, model_dtype): + default_dtype = str(torch.get_default_dtype()).lstrip("torch.") + builder, true_config, tmpdir, env = fake_model_training_session - # == Load metrics == - outdir = f"{tmpdir}/{true_config['root']}/{run_name}/" + # == Load metrics == + outdir = f"{tmpdir}/{true_config['root']}/{true_config['run_name']}/" - if builder == IdentityModel or builder == LearningFactorModel: - for which in ("train", "val"): + if builder == IdentityModel or builder == LearningFactorModel: + for which in ("train", "val"): - dat = np.genfromtxt( - f"{outdir}/metrics_batch_{which}.csv", - delimiter=",", - names=True, - dtype=None, - ) - for field in dat.dtype.names: - if field == "epoch" or field == "batch": - continue - # Everything else should be a loss or a metric - if builder == IdentityModel: + dat = np.genfromtxt( + f"{outdir}/metrics_batch_{which}.csv", + delimiter=",", + names=True, + dtype=None, + ) + for field in dat.dtype.names: + if field == "epoch" or field == "batch": + continue + # Everything else should be a loss or a metric + if builder == IdentityModel: + if model_dtype == default_dtype: + # We have a true identity model assert np.allclose( - dat[field], 0.0 + dat[field], + 0.0, + atol=1e-6 if default_dtype == "float32" else 1e-9, ), f"Loss/metric `{field}` wasn't all zeros for {which}" - elif builder == LearningFactorModel: - assert ( - dat[field][-1] < dat[field][0] - ), f"Loss/metric `{field}` didn't go down for {which}" - - # epoch metrics - dat = np.genfromtxt( - f"{outdir}/metrics_epoch.csv", - delimiter=",", - names=True, - dtype=None, - ) - for field in dat.dtype.names: - if field == "epoch" or field == "wall" or field == "LR": - continue - - # Everything else should be a loss or a metric - if builder == IdentityModel: + else: + # we have an approximate identity model that applies a floating point truncation + # in the actual aspirin test data used here, the truncation error is maximally 0.0155 + # there is also no rescaling so everything is in real units here + assert np.all( + dat[field] < 0.02 + ), f"Loss/metric `{field}` wasn't approximately zeros for {which}" + elif builder == LearningFactorModel: + assert ( + dat[field][-1] < dat[field][0] + ), f"Loss/metric `{field}` didn't go down for {which}" + + # epoch metrics + dat = np.genfromtxt( + f"{outdir}/metrics_epoch.csv", + delimiter=",", + names=True, + dtype=None, + ) + for field in dat.dtype.names: + if field == "epoch" or field == "wall" or field == "LR": + continue + + # Everything else should be a loss or a metric + if builder == IdentityModel: + if model_dtype == default_dtype: + # we have a true identity model assert np.allclose( - dat[field][1:], 0.0 + dat[field][1:], + 0.0, + atol=1e-6 if default_dtype == "float32" else 1e-9, ), f"Loss/metric `{field}` wasn't all equal to zero for epoch" - elif builder == ConstFactorModel: - # otherwise just check its constant. - # epoch-wise numbers should be the same, since there's no randomness at this level - assert np.allclose( - dat[field], dat[field][0] - ), f"Loss/metric `{field}` wasn't all equal to {dat[field][0]} for epoch" - elif builder == LearningFactorModel: - assert ( - dat[field][-1] < dat[field][0] - ), f"Loss/metric `{field}` didn't go down across epochs" - - # == Check model == - model = torch.load(outdir + "/last_model.pth") - - if builder == IdentityModel: - one = model["one"] - # Since the loss is always zero, even though the constant - # 1 was trainable, it shouldn't have changed - assert torch.allclose( - one, torch.ones(1, device=one.device, dtype=one.dtype) - ) + else: + # we have an approximate identity model that applies a floating point truncation + # see above + assert np.all( + dat[field][1:] < 0.02 + ), f"Loss/metric `{field}` wasn't approximately zeros for {which}" + elif builder == ConstFactorModel: + # otherwise just check its constant. + # epoch-wise numbers should be the same, since there's no randomness at this level + assert np.allclose( + dat[field], dat[field][0] + ), f"Loss/metric `{field}` wasn't all equal to {dat[field][0]} for epoch" + elif builder == LearningFactorModel: + assert ( + dat[field][-1] < dat[field][0] + ), f"Loss/metric `{field}` didn't go down across epochs" + + # == Check model == + model = torch.load(outdir + "/last_model.pth") + + if builder == IdentityModel: + # GraphModel.IdentityModel + zero = model["model.zero"] + # Since the loss is always zero, even though the constant + # 1 was trainable, it shouldn't have changed + # the tolerances when loss is nonzero are large-ish because the default learning rate 0.01 is high + # these tolerances are _also_ in real units + assert torch.allclose( + zero, + torch.zeros(1, device=zero.device, dtype=zero.dtype), + atol=1e-7 if model_dtype == default_dtype else 1e-2, + ) @pytest.mark.parametrize( @@ -213,14 +120,10 @@ def test_metrics(nequip_dataset, BENCHMARK_ROOT, conffile, builder): ], ) def test_requeue(nequip_dataset, BENCHMARK_ROOT, conffile): - - builder = IdentityModel + # TODO test metrics against one that goes all the way through + builder = IdentityModel # TODO: train a real model? dtype = str(torch.get_default_dtype())[len("torch.") :] - # if torch.cuda.is_available(): - # # TODO: is this true? - # pytest.skip("CUDA and subprocesses have issues") - path_to_this_file = pathlib.Path(__file__) config_path = path_to_this_file.parents[2] / f"configs/{conffile}" true_config = yaml.load(config_path.read_text(), Loader=yaml.Loader) @@ -238,7 +141,10 @@ def test_requeue(nequip_dataset, BENCHMARK_ROOT, conffile): # We just don't add rescaling: true_config["model_builders"] = [builder] # We need truth labels as inputs for these fake testing models - true_config["_override_allow_truth_label_inputs"] = True + true_config["model_input_fields"] = { + AtomicDataDict.FORCE_KEY: "1o", + AtomicDataDict.TOTAL_ENERGY_KEY: "0e", + } for irun in range(3): @@ -255,13 +161,14 @@ def test_requeue(nequip_dataset, BENCHMARK_ROOT, conffile): ) retcode = subprocess.run( - ["nequip-train", "conf.yaml"], + # Supress the warning cause we use general config for all the fake models + ["nequip-train", "conf.yaml", "--warn-unused"], cwd=tmpdir, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) - retcode.check_returncode() + _check_and_print(retcode) # == Load metrics == dat = np.genfromtxt( diff --git a/tests/unit/data/test_dataloader.py b/tests/unit/data/test_dataloader.py index 5fbeeb93..8c7cc510 100644 --- a/tests/unit/data/test_dataloader.py +++ b/tests/unit/data/test_dataloader.py @@ -54,10 +54,13 @@ def test_subset_sampler(self, npz_dataset): print(batch) +NPZ_DATASET_FIXTURE_N_FRAMES: int = 8 + + @pytest.fixture(scope="module") def npz_dataset(): natoms = 3 - nframes = 8 + nframes = NPZ_DATASET_FIXTURE_N_FRAMES npz = dict( positions=np.random.random((nframes, natoms, 3)), force=np.random.random((nframes, natoms, 3)), @@ -69,7 +72,7 @@ def npz_dataset(): a = NpzDataset( file_name=folder + "/npzdata.npz", root=folder, - extra_fixed_fields={"r_max": 3}, + AtomicData_options={"r_max": 3}, ) yield a diff --git a/tests/unit/data/test_dataset.py b/tests/unit/data/test_dataset.py index 95cfe48d..001f0c3c 100644 --- a/tests/unit/data/test_dataset.py +++ b/tests/unit/data/test_dataset.py @@ -14,6 +14,7 @@ AtomicInMemoryDataset, NpzDataset, ASEDataset, + HDF5Dataset, dataset_from_config, register_fields, deregister_fields, @@ -59,11 +60,32 @@ def npz_dataset(npz_data, temp_data): a = NpzDataset( file_name=npz_data, root=temp_data + "/test_dataset", - extra_fixed_fields={"r_max": 3}, + AtomicData_options={"r_max": 3}, ) yield a +@pytest.fixture(scope="function") +def hdf5_dataset(npz, temp_data): + try: + import h5py + except ModuleNotFoundError: + pytest.skip("h5py is not installed") + + with tempfile.NamedTemporaryFile(suffix=".hdf5") as path: + f = h5py.File(path.name, "w") + group = f.create_group("samples") + group.create_dataset("atomic_numbers", data=npz["Z"], dtype=np.int8) + group.create_dataset("pos", data=npz["positions"], dtype=np.float32) + group.create_dataset("energy", data=npz["energy"], dtype=np.float32) + group.create_dataset("forces", data=npz["force"], dtype=np.float32) + yield HDF5Dataset( + file_name=path.name, + root=temp_data + "/test_dataset", + AtomicData_options={"r_max": 3}, + ) + + @pytest.fixture(scope="function") def root(): with tempfile.TemporaryDirectory(prefix="datasetroot") as path: @@ -86,7 +108,7 @@ def test_init(self): assert str(excinfo.value) == "" def test_npz(self, npz_data, root): - g = NpzDataset(file_name=npz_data, root=root, extra_fixed_fields={"r_max": 3.0}) + g = NpzDataset(file_name=npz_data, root=root, AtomicData_options={"r_max": 3.0}) assert isdir(g.root) assert isdir(g.processed_dir) assert isfile(g.processed_dir + "/data.pth") @@ -95,7 +117,7 @@ def test_ase(self, ase_file, root): a = ASEDataset( file_name=ase_file, root=root, - extra_fixed_fields={"r_max": 3.0}, + AtomicData_options={"r_max": 3.0}, ase_args=dict(format="extxyz"), ) assert isdir(a.root) @@ -118,9 +140,10 @@ def test_callable(self, npz_dataset, npz): # By default we follow torch convention of defaulting to the unbiased std assert np.allclose(np.std(f_raveled, ddof=1), f_std) - def test_statistics(self, npz_dataset, npz): - - (eng_mean, eng_std), (Z_unique, Z_count) = npz_dataset.statistics( + @pytest.mark.parametrize("dataset_type", ["npz_dataset", "hdf5_dataset"]) + def test_statistics(self, dataset_type, npz, request): + dataset = request.getfixturevalue(dataset_type) + (eng_mean, eng_std), (Z_unique, Z_count) = dataset.statistics( fields=[AtomicDataDict.TOTAL_ENERGY_KEY, AtomicDataDict.ATOMIC_NUMBERS_KEY], modes=["mean_std", "count"], ) @@ -138,9 +161,9 @@ def test_statistics(self, npz_dataset, npz): assert np.all(Z_unique == uniq) assert np.all(Z_count == count) - def test_with_subset(self, npz_dataset, npz): - - dataset = npz_dataset.index_select([0]) + @pytest.mark.parametrize("dataset_type", ["npz_dataset", "hdf5_dataset"]) + def test_with_subset(self, dataset_type, npz, request): + dataset = request.getfixturevalue(dataset_type).index_select([0]) ((Z_unique, Z_count), (force_rms,)) = dataset.statistics( [AtomicDataDict.ATOMIC_NUMBERS_KEY, AtomicDataDict.FORCE_KEY], @@ -155,8 +178,10 @@ def test_with_subset(self, npz_dataset, npz): force_rms.numpy(), np.sqrt(np.mean(np.square(npz["force"][0]))) ) - def test_atom_types(self, npz_dataset): - ((avg_num_neigh, _),) = npz_dataset.statistics( + @pytest.mark.parametrize("dataset_type", ["npz_dataset", "hdf5_dataset"]) + def test_atom_types(self, dataset_type, request): + dataset = request.getfixturevalue(dataset_type) + ((avg_num_neigh, _),) = dataset.statistics( fields=[ lambda data: ( torch.unique( @@ -170,11 +195,13 @@ def test_atom_types(self, npz_dataset): # They are all homogenous in this dataset: assert ( avg_num_neigh - == torch.bincount(npz_dataset[0][AtomicDataDict.EDGE_INDEX_KEY][0])[0] + == torch.bincount(dataset[0][AtomicDataDict.EDGE_INDEX_KEY][0])[0] ) - def test_edgewise_stats(self, npz_dataset): - ((avg_edge_length, std_edge_len),) = npz_dataset.statistics( + @pytest.mark.parametrize("dataset_type", ["npz_dataset", "hdf5_dataset"]) + def test_edgewise_stats(self, dataset_type, request): + dataset = request.getfixturevalue(dataset_type) + ((avg_edge_length, std_edge_len),) = dataset.statistics( fields=[ lambda data: ( ( @@ -190,15 +217,21 @@ def test_edgewise_stats(self, npz_dataset): ], modes=["mean_std"], ) - collater = Collater.for_dataset(npz_dataset) - all_data = collater([npz_dataset[i] for i in range(len(npz_dataset))]) + collater = Collater.for_dataset(dataset) + all_data = collater([dataset[i] for i in range(len(dataset))]) all_data = AtomicData.to_AtomicDataDict(all_data) all_data = AtomicDataDict.with_edge_vectors(all_data, with_lengths=True) assert torch.allclose( - avg_edge_length, torch.mean(all_data[AtomicDataDict.EDGE_LENGTH_KEY]) + avg_edge_length, + torch.mean(all_data[AtomicDataDict.EDGE_LENGTH_KEY]).to( + avg_edge_length.dtype + ), ) assert torch.allclose( - std_edge_len, torch.std(all_data[AtomicDataDict.EDGE_LENGTH_KEY]) + std_edge_len, + torch.std(all_data[AtomicDataDict.EDGE_LENGTH_KEY]).to( + avg_edge_length.dtype + ), ) @@ -206,7 +239,7 @@ class TestPerAtomStatistics: @pytest.mark.parametrize("mode", ["mean_std", "rms"]) def test_per_node_field(self, npz_dataset, mode): # set up the transformer - npz_dataset = set_up_transformer(npz_dataset, True, False, False) + npz_dataset = set_up_transformer(npz_dataset, True, False) with pytest.raises(ValueError) as excinfo: npz_dataset.statistics( @@ -218,16 +251,15 @@ def test_per_node_field(self, npz_dataset, mode): == f"It doesn't make sense to ask for `{mode}` since `{AtomicDataDict.BATCH_KEY}` is not per-graph" ) - @pytest.mark.parametrize("fixed_field", [True, False]) @pytest.mark.parametrize("subset", [True, False]) @pytest.mark.parametrize( "key,dim", [(AtomicDataDict.TOTAL_ENERGY_KEY, (1,)), ("somekey", (3,))] ) - def test_per_graph_field(self, npz_dataset, fixed_field, subset, key, dim): + def test_per_graph_field(self, npz_dataset, subset, key, dim): if key == "somekey": register_fields(graph_fields=[key]) - npz_dataset = set_up_transformer(npz_dataset, True, fixed_field, subset) + npz_dataset = set_up_transformer(npz_dataset, True, subset) if npz_dataset is None: return @@ -262,14 +294,11 @@ def test_per_graph_field(self, npz_dataset, fixed_field, subset, key, dim): class TestPerSpeciesStatistics: - @pytest.mark.parametrize("fixed_field", [True, False]) @pytest.mark.parametrize("mode", ["mean_std", "rms"]) @pytest.mark.parametrize("subset", [True, False]) - def test_per_node_field(self, npz_dataset, fixed_field, mode, subset): + def test_per_node_field(self, npz_dataset, mode, subset): # set up the transformer - npz_dataset = set_up_transformer( - npz_dataset, not fixed_field, fixed_field, subset - ) + npz_dataset = set_up_transformer(npz_dataset, True, subset) (result,) = npz_dataset.statistics( [AtomicDataDict.BATCH_KEY], @@ -278,15 +307,13 @@ def test_per_node_field(self, npz_dataset, fixed_field, mode, subset): print(result) @pytest.mark.parametrize("alpha", [0, 1e-3, 0.01]) - @pytest.mark.parametrize("fixed_field", [True, False]) @pytest.mark.parametrize("full_rank", [True, False]) @pytest.mark.parametrize("subset", [True, False]) - def test_per_graph_field(self, npz_dataset, alpha, fixed_field, full_rank, subset): - + def test_per_graph_field(self, npz_dataset, alpha, full_rank, subset): if alpha <= 1e-4 and not full_rank: return - npz_dataset = set_up_transformer(npz_dataset, full_rank, fixed_field, subset) + npz_dataset = set_up_transformer(npz_dataset, full_rank, subset) if npz_dataset is None: return @@ -351,14 +378,14 @@ class TestReload: @pytest.mark.parametrize("give_url", [True, False]) @pytest.mark.parametrize("change_key_map", [True, False]) def test_reload(self, npz_dataset, npz_data, change_rmax, give_url, change_key_map): - r_max = npz_dataset.extra_fixed_fields["r_max"] + change_rmax + r_max = npz_dataset.AtomicData_options["r_max"] + change_rmax keymap = npz_dataset.key_mapping.copy() # the default one if change_key_map: keymap["x1"] = "x2" a = NpzDataset( file_name=npz_data, root=npz_dataset.root, - extra_fixed_fields={"r_max": r_max}, + AtomicData_options={"r_max": r_max}, key_mapping=keymap, **({"url": "example.com/data.dat"} if give_url else {}), ) @@ -373,10 +400,10 @@ class TestFromConfig: @pytest.mark.parametrize( "args", [ - dict(extra_fixed_fields={"r_max": 3.0}), - dict(dataset_extra_fixed_fields={"r_max": 3.0}), + dict(AtomicData_options={"r_max": 3.0}), + dict(dataset_AtomicData_options={"r_max": 3.0}), dict(r_max=3.0), - dict(r_max=3.0, extra_fixed_fields={}), + dict(r_max=3.0, AtomicData_options={}), ], ) def test_npz(self, npz_data, root, args): @@ -392,7 +419,7 @@ def test_npz(self, npz_data, root, args): ) ) g = dataset_from_config(config) - assert g.fixed_fields["r_max"] == 3 + assert g.AtomicData_options["r_max"] == 3 assert isdir(g.root) assert isdir(g.processed_dir) assert isfile(g.processed_dir + "/data.pth") @@ -403,7 +430,7 @@ def test_ase(self, ase_file, root, prefix): dict( file_name=ase_file, root=root, - extra_fixed_fields={"r_max": 3.0}, + AtomicData_options={"r_max": 3.0}, ase_args=dict(format="extxyz"), chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}, ) @@ -427,7 +454,7 @@ def test_ase(self, ase_file, root, prefix): class TestFromList: def test_from_atoms(self, molecules): dataset = ASEDataset.from_atoms_list( - molecules, extra_fixed_fields={"r_max": 4.5} + molecules, AtomicData_options={"r_max": 4.5} ) assert len(dataset) == len(molecules) for i, mol in enumerate(molecules): @@ -448,13 +475,8 @@ def generate_E(N, mean_min, mean_max, std): return ref_mean, ref_std, (N * E).sum(axis=-1) -def set_up_transformer(npz_dataset, full_rank, fixed_field, subset): - +def set_up_transformer(npz_dataset, full_rank, subset): if full_rank: - - if fixed_field: - return - unique = torch.unique(npz_dataset.data[AtomicDataDict.ATOMIC_NUMBERS_KEY]) npz_dataset.transform = TypeMapper( chemical_symbol_to_type={ @@ -466,19 +488,9 @@ def set_up_transformer(npz_dataset, full_rank, fixed_field, subset): # let all atoms to be the same type distribution num_nodes = npz_dataset.data[AtomicDataDict.BATCH_KEY].shape[0] - if fixed_field: - del npz_dataset.data[AtomicDataDict.ATOMIC_NUMBERS_KEY] - del npz_dataset.data.__slices__[ - AtomicDataDict.ATOMIC_NUMBERS_KEY - ] # remove batch metadata for the key - new_n = torch.ones(NATOMS, dtype=torch.int64) - new_n[0] += ntype - npz_dataset.fixed_fields[AtomicDataDict.ATOMIC_NUMBERS_KEY] = new_n - else: - npz_dataset.fixed_fields.pop(AtomicDataDict.ATOMIC_NUMBERS_KEY, None) - new_n = torch.ones(num_nodes, dtype=torch.int64) - new_n[::NATOMS] += ntype - npz_dataset.data[AtomicDataDict.ATOMIC_NUMBERS_KEY] = new_n + new_n = torch.ones(num_nodes, dtype=torch.int64) + new_n[::NATOMS] += ntype + npz_dataset.data[AtomicDataDict.ATOMIC_NUMBERS_KEY] = new_n # set up the transformer npz_dataset.transform = TypeMapper( diff --git a/tests/unit/data/test_sampler.py b/tests/unit/data/test_sampler.py new file mode 100644 index 00000000..1b249d65 --- /dev/null +++ b/tests/unit/data/test_sampler.py @@ -0,0 +1,84 @@ +import pytest +import itertools + +import torch + +from nequip.data import PartialSampler + +from test_dataloader import npz_dataset, NPZ_DATASET_FIXTURE_N_FRAMES # noqa + + +@pytest.fixture(params=[True, False], scope="module") +def shuffle(request) -> bool: + return request.param + + +@pytest.fixture( + params=[None, 1, 2, 5, 7, NPZ_DATASET_FIXTURE_N_FRAMES], scope="function" +) +def sampler(request, npz_dataset, shuffle) -> PartialSampler: # noqa: F811 + return PartialSampler( + data_source=npz_dataset, + shuffle=shuffle, + num_samples_per_epoch=request.param, + generator=torch.Generator().manual_seed(0), + ) + + +def test_partials_add_up(sampler: PartialSampler): + """Confirm that full data epochs are (random permutations of) the list of all dataset indexes""" + seq = [] + for epoch_i in range(2 * sampler.num_samples_total + 1): + sampler.step_epoch(epoch_i) + seq.extend(iter(sampler)) + + seq = [int(e) for e in seq] + + if sampler.shuffle: + # make sure we've at least hit every frame once + assert set(seq) == set(range(sampler.num_samples_total)) + # then go through it by dataset epochs + i = 0 + while True: + data_epoch_idexes = seq[i : i + sampler.num_samples_total] + if len(data_epoch_idexes) == 0: + break + if len(data_epoch_idexes) == sampler.num_samples_total: + # it should be a random permutation + assert set(data_epoch_idexes) == set(range(sampler.num_samples_total)) + elif len(data_epoch_idexes) < sampler.num_samples_total: + # we hae a partial dataset epoch at the end + assert set(data_epoch_idexes) <= set(range(sampler.num_samples_total)) + assert len(set(data_epoch_idexes)) == len(data_epoch_idexes) + else: + assert False + i += sampler.num_samples_total + else: + # make sure its a repeating sequence of aranges + assert ( + seq + == list( + itertools.chain( + *[ + range(sampler.num_samples_total) + for _ in range(sampler._epoch + 2) + ] + ) + )[: len(seq)] + ) + + +def test_epoch_count(sampler: PartialSampler): + with pytest.raises(AssertionError): + list(iter(sampler)) + sampler.step_epoch(0) + assert sampler._epoch == 0 + assert sampler._prev_epoch is None + list(iter(sampler)) + assert sampler._prev_epoch == 0 + with pytest.raises(AssertionError): + list(iter(sampler)) + sampler.step_epoch(1) + list(iter(sampler)) + assert sampler._epoch == 1 + assert sampler._prev_epoch == 1 # since that's the prev epoch we've just completed diff --git a/tests/unit/model/test_builder_utils.py b/tests/unit/model/test_builder_utils.py index caebde55..c90327d8 100644 --- a/tests/unit/model/test_builder_utils.py +++ b/tests/unit/model/test_builder_utils.py @@ -27,7 +27,7 @@ def test_avg_num(molecules, temp_data, r_max, subset, to_test): nequip_dataset = ASEDataset( file_name=fp.name, root=temp_data, - extra_fixed_fields={"r_max": r_max}, + AtomicData_options={"r_max": r_max}, ase_args=dict(format="extxyz"), type_mapper=TypeMapper(chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}), ) diff --git a/tests/unit/model/test_nequip_model.py b/tests/unit/model/test_nequip_model.py index 2aa82e15..ee4d9ab5 100644 --- a/tests/unit/model/test_nequip_model.py +++ b/tests/unit/model/test_nequip_model.py @@ -10,7 +10,10 @@ COMMON_CONFIG = { "avg_num_neighbors": None, "num_types": 3, - "types_names": ["H", "C", "O"], + "chemical_symbol_to_type": {"H": 0, "C": 1, "O": 2}, + # Just in case for when that builder exists: + "pair_style": "ZBL", + "units": "metal", } r_max = 3 minimal_config1 = dict( @@ -78,15 +81,26 @@ def base_config(self, request): AtomicDataDict.FORCE_KEY, ], ), + # # Save some time in the tests + # ( + # ["EnergyModel"], + # [ + # AtomicDataDict.TOTAL_ENERGY_KEY, + # AtomicDataDict.PER_ATOM_ENERGY_KEY, + # ], + # ), ( - ["EnergyModel"], + ["EnergyModel", "StressForceOutput"], [ AtomicDataDict.TOTAL_ENERGY_KEY, AtomicDataDict.PER_ATOM_ENERGY_KEY, + AtomicDataDict.FORCE_KEY, + AtomicDataDict.STRESS_KEY, + AtomicDataDict.VIRIAL_KEY, ], ), ( - ["EnergyModel", "StressForceOutput"], + ["EnergyModel", "PairPotentialTerm", "StressForceOutput"], [ AtomicDataDict.TOTAL_ENERGY_KEY, AtomicDataDict.PER_ATOM_ENERGY_KEY, @@ -109,14 +123,14 @@ def test_submods(self): config = minimal_config2.copy() config["model_builders"] = ["EnergyModel"] model = model_from_config(config=config, initialize=True) - assert isinstance(model.chemical_embedding, AtomwiseLinear) + chemical_embedding = model.model.chemical_embedding + assert isinstance(chemical_embedding, AtomwiseLinear) true_irreps = o3.Irreps(minimal_config2["chemical_embedding_irreps_out"]) assert ( - model.chemical_embedding.irreps_out[model.chemical_embedding.out_field] - == true_irreps + chemical_embedding.irreps_out[chemical_embedding.out_field] == true_irreps ) # Make sure it propagates assert ( - model.layer0_convnet.irreps_in[model.chemical_embedding.out_field] + model.model.layer0_convnet.irreps_in[chemical_embedding.out_field] == true_irreps ) diff --git a/tests/unit/model/test_pair/.gitignore b/tests/unit/model/test_pair/.gitignore new file mode 100644 index 00000000..686a8db1 --- /dev/null +++ b/tests/unit/model/test_pair/.gitignore @@ -0,0 +1,2 @@ +log.lammps +zbl.dat \ No newline at end of file diff --git a/tests/unit/model/test_pair/test_zbl.py b/tests/unit/model/test_pair/test_zbl.py new file mode 100644 index 00000000..b862b624 --- /dev/null +++ b/tests/unit/model/test_pair/test_zbl.py @@ -0,0 +1,94 @@ +import pytest + +import numpy as np +from pathlib import Path + +import ase +import ase.io +import ase.data + +import torch + +from nequip.data.transforms import TypeMapper +from nequip.data import AtomicDataDict +from nequip.model import model_from_config +from nequip.ase import NequIPCalculator +from nequip.utils import Config +from nequip.utils.unittests.model_tests import BaseEnergyModelTests + + +class TestNequIPModel(BaseEnergyModelTests): + @pytest.fixture + def strict_locality(self): + return True + + @pytest.fixture( + params=[False, True], + scope="class", + ) + def config(self, request): + do_scale = request.param + config = { + "model_builders": [ + "PairPotential", + "ForceOutput", + "RescaleEnergyEtc", + ], + "global_rescale_scale": 3.7777 if do_scale else None, + "pair_style": "ZBL", + "units": "metal", + "r_max": 5.0, + "chemical_symbol_to_type": {"H": 0, "C": 1, "O": 2}, + } + return config, [ + AtomicDataDict.TOTAL_ENERGY_KEY, + AtomicDataDict.FORCE_KEY, + AtomicDataDict.PER_ATOM_ENERGY_KEY, + ] + + def test_lammps_repro(self, config): + if torch.get_default_dtype() != torch.float64: + pytest.skip() + config, _ = config + config = config.copy() + r_max: float = 8.0 # see zbl_data.lmps + config.update( + { + "model_dtype": "float64", + "r_max": r_max + 1, # To make cutoff envelope irrelevant + "PolynomialCutoff_p": 80, # almost a step function + } + ) + config["chemical_symbol_to_type"] = { + "H": 0, + "O": 1, + "C": 2, + "N": 3, + "Cu": 4, + "Au": 5, + } + tm = TypeMapper(chemical_symbol_to_type=config["chemical_symbol_to_type"]) + config["num_types"] = tm.num_types + ZBL_model = model_from_config(Config.from_dict(config), initialize=True) + ZBL_model.eval() + # make test system of two atoms: + atoms = ase.Atoms(positions=np.zeros((2, 3)), symbols=["H", "H"]) + atoms.calc = NequIPCalculator( + ZBL_model, r_max=r_max, device="cpu", transform=tm + ) + # == load precomputed reference data == + # To regenerate this data, run + # $ lmp -in zbl_data.lmps + # $ python -c "import numpy as np; d = np.loadtxt('zbl.dat', skiprows=1); np.save('zbl.npy', d)" + refdata = np.load(Path(__file__).parent / "zbl.npy") + for (r, Zi, Zj, pe, fxi, fxj) in refdata: + if r >= r_max: + continue + atoms.positions[1, 0] = r + atoms.set_atomic_numbers([int(Zi), int(Zj)]) + # ZBL blows up for atoms being close, so the numerics differ to ours + # 1e-5 == 0.01 meV / Å + assert np.allclose(atoms.get_forces()[0, 0], fxi, atol=1e-5) + assert np.allclose(atoms.get_forces()[1, 0], fxj, atol=1e-5) + # 1e-4 == 0.1 meV system, 0.05 meV / atom + assert np.allclose(atoms.get_potential_energy(), pe, atol=1e-4) diff --git a/tests/unit/model/test_pair/zbl.npy b/tests/unit/model/test_pair/zbl.npy new file mode 100644 index 00000000..626bccde Binary files /dev/null and b/tests/unit/model/test_pair/zbl.npy differ diff --git a/tests/unit/model/test_pair/zbl_data.lmps b/tests/unit/model/test_pair/zbl_data.lmps new file mode 100644 index 00000000..6afa6e0f --- /dev/null +++ b/tests/unit/model/test_pair/zbl_data.lmps @@ -0,0 +1,46 @@ +units metal +atom_style atomic +atom_modify map yes +thermo 1 + +region 1 block 0 10 0 10 0 10 +boundary s s s +create_box 2 1 +create_atoms 1 single 0.0 0.0 0.0 +create_atoms 2 single 0.1 0.0 0.0 + +mass 1 1.0 +mass 2 1.0 + +group 2 type 2 + +neighbor 1.0 nsq # tiny box +neigh_modify delay 0 every 1 check no + +variable rmax string 8.0 +variable N string 50 + +pair_style zbl $(v_rmax) $(v_rmax) +print "r Zi Zj pe fxi fxj" file zbl.dat + +variable Zi index 1.0 6.0 7.0 8.0 29.0 79.0 +label Ziloop + + variable Zj index 1.0 6.0 7.0 8.0 29.0 79.0 + label Zjloop + pair_coeff 1 1 $(v_Zi) $(v_Zi) + pair_coeff 2 2 $(v_Zj) $(v_Zj) + + variable i loop $(v_N) + label rloop + set atom 2 x $(v_i * v_rmax / v_N) + run 0 + print "$(x[2]) $(v_Zi) $(v_Zj) $(pe) $(fx[1]) $(fx[2])" append zbl.dat + next i + jump SELF rloop + + next Zj + jump SELF Zjloop + +next Zi +jump SELF Ziloop \ No newline at end of file diff --git a/tests/unit/trainer/test_trainer.py b/tests/unit/trainer/test_trainer.py index 860be357..197f3897 100644 --- a/tests/unit/trainer/test_trainer.py +++ b/tests/unit/trainer/test_trainer.py @@ -14,7 +14,7 @@ from nequip.data import AtomicDataDict from nequip.train.trainer import Trainer from nequip.utils.savenload import load_file -from nequip.nn import GraphModuleMixin +from nequip.nn import GraphModuleMixin, GraphModel, RescaleOutput def dummy_builder(): @@ -45,16 +45,17 @@ def dummy_builder(): ) -@pytest.fixture(scope="class") -def trainer(): +@pytest.fixture(scope="function") +def trainer(float_tolerance): """ Generate a class instance with minimal configurations """ - minimal_config["default_dtype"] = str(torch.get_default_dtype())[len("torch.") :] - model = model_from_config(minimal_config) + conf = minimal_config.copy() + conf["default_dtype"] = str(torch.get_default_dtype())[len("torch.") :] + model = model_from_config(conf) with tempfile.TemporaryDirectory(prefix="output") as path: - minimal_config["root"] = path - c = Trainer(model=model, **minimal_config) + conf["root"] = path + c = Trainer(model=model, **conf) yield c @@ -73,14 +74,14 @@ def test_duplicate_id_2(self, temp_data): check whether the Output class can automatically insert timestr when a workdir has pre-existed """ + conf = minimal_config.copy() + conf["root"] = temp_data - minimal_config["root"] = temp_data - - model = DummyNet(3) - Trainer(model=model, **minimal_config) + model = GraphModel(DummyNet(3)) + Trainer(model=model, **conf) with pytest.raises(RuntimeError): - Trainer(model=model, **minimal_config) + Trainer(model=model, **conf) class TestSaveLoad: @@ -145,11 +146,17 @@ def test_from_file(self, trainer, append): class TestData: @pytest.mark.parametrize("mode", ["random", "sequential"]) def test_split(self, trainer, nequip_dataset, mode): - trainer.train_val_split = mode trainer.set_dataset(nequip_dataset) - for i, batch in enumerate(trainer.dl_train): - print(i, batch) + for epoch_i in range(3): + trainer.dl_train_sampler.step_epoch(epoch_i) + n_samples: int = 0 + for i, batch in enumerate(trainer.dl_train): + n_samples += batch[AtomicDataDict.BATCH_PTR_KEY].shape[0] - 1 + if trainer.n_train_per_epoch is not None: + assert n_samples == trainer.n_train_per_epoch + else: + assert n_samples == trainer.n_train class TestTrain: @@ -281,15 +288,19 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: return data -class DummyScale(torch.nn.Module): +# subclass to make sure it gets picked up by GraphModel +class DummyScale(RescaleOutput): """mimic the rescale model""" def __init__(self, key, scale, shift) -> None: - super().__init__() + torch.nn.Module.__init__(self) # skip RescaleOutput's __init__ self.key = key self.scale_by = torch.as_tensor(scale, dtype=torch.get_default_dtype()) self.shift_by = torch.as_tensor(shift, dtype=torch.get_default_dtype()) self.linear2 = Linear(3, 3) + self.irreps_in = {} + self.irreps_out = {key: "3x0e"} + self.model = None def forward(self, data): out = self.linear2(data["pos"]) @@ -317,7 +328,8 @@ def unscale(self, data, force_process=False): def scale_train(nequip_dataset): with tempfile.TemporaryDirectory(prefix="output") as path: trainer = Trainer( - model=DummyScale(AtomicDataDict.FORCE_KEY, scale=1.3, shift=1), + model=GraphModel(DummyScale(AtomicDataDict.FORCE_KEY, scale=1.3, shift=1)), + seed=9, n_train=4, n_val=4, max_epochs=0, diff --git a/tests/unit/utils/test_config.py b/tests/unit/utils/test_config.py index 0cd3151e..35ae7b68 100644 --- a/tests/unit/utils/test_config.py +++ b/tests/unit/utils/test_config.py @@ -36,12 +36,12 @@ def test_init(self, config): @config_testlist def test_set_attr(self, config): - dict_config = dict(config) + dict_config = Config.as_dict(config) config.intv = 2 dict_config["intv"] = 2 - assert dict(config) == dict_config - print("dict", dict(config)) + assert Config.as_dict(config) == dict_config + print("dict", Config.as_dict(config)) @config_testlist def test_get_attr(self, config): @@ -69,7 +69,7 @@ def test_save_yaml(self, config): @one_test def test_load_yaml(self, config): config2 = config.load(filename=f"{self.filename}.yaml") - assert dict(config) == dict(config2) + assert Config.as_dict(config) == dict(config2) remove(f"{self.filename}.yaml") @@ -81,14 +81,14 @@ class TestConfigUpdate: @config_testlist def test_update(self, config): - dict_config = dict(config) + dict_config = Config.as_dict(config) dict_config["new_intv"] = 9 newdict = {"new_intv": 9} config.update(newdict) - assert dict(config) == dict_config + assert Config.as_dict(config) == dict_config @config_testlist def test_update_settype(self, config): diff --git a/tests/unit/utils/test_gmm.py b/tests/unit/utils/test_gmm.py new file mode 100644 index 00000000..84628ddd --- /dev/null +++ b/tests/unit/utils/test_gmm.py @@ -0,0 +1,123 @@ +import torch +import pytest +import numpy as np +from nequip.utils import gmm +from sklearn import mixture +from e3nn.util.test import assert_auto_jitable + + +class TestGMM: + # Seed for tests + @pytest.fixture + def seed(self): + return 678912345 + + # Data sets for fitting GMMs and scoring NLLs + @pytest.fixture(params=[[10, 8], [500, 32]]) + def feature_data(self, seed, request): + fit_data = 2 * ( + torch.randn( + request.param[0], + request.param[1], + generator=torch.Generator().manual_seed(seed), + ) + - 0.5 + ) + test_data = 2 * ( + torch.randn( + request.param[0] * 2, + request.param[1], + generator=torch.Generator().manual_seed(seed - 123456789), + ) + - 0.5 + ) + return {"fit_data": fit_data, "test_data": test_data} + + # Sklearn GMM for tests + @pytest.fixture + def gmm_sklearn(self, seed): + return mixture.GaussianMixture( + n_components=8, covariance_type="full", random_state=seed + ) + + # Torch GMM for small data set tests + @pytest.fixture + def gmm_torch(self, feature_data): + return gmm.GaussianMixture( + n_features=feature_data["fit_data"].size(dim=1), n_components=8 + ) + + # Test compilation + def test_compile(self, gmm_torch): + assert_auto_jitable(gmm_torch) + + # Test agreement between sklearn and torch GMMs + def test_fit_forward(self, seed, gmm_sklearn, gmm_torch, feature_data): + gmm_sklearn.fit(feature_data["fit_data"].numpy()) + gmm_torch.fit(feature_data["fit_data"], rng=seed) + + assert torch.allclose(torch.from_numpy(gmm_sklearn.means_), gmm_torch.means) + assert torch.allclose( + torch.from_numpy(gmm_sklearn.covariances_), gmm_torch.covariances + ) + assert torch.allclose(torch.from_numpy(gmm_sklearn.weights_), gmm_torch.weights) + assert torch.allclose( + torch.from_numpy(gmm_sklearn.precisions_cholesky_), + gmm_torch.precisions_cholesky, + ) + + sklearn_nll = gmm_sklearn.score_samples(feature_data["test_data"].numpy()) + torch_nll = gmm_torch(feature_data["test_data"]) + + assert torch.allclose(-torch.from_numpy(sklearn_nll), torch_nll) + + # Test agreement between sklearn and torch using BIC + def test_fit_forward_bic(self, seed, feature_data): + components = list(range(1, min(50, feature_data["fit_data"].size(dim=0)))) + gmms = [ + mixture.GaussianMixture( + n_components=n, covariance_type="full", random_state=seed + ) + for n in components + ] + bics = [ + model.fit(feature_data["fit_data"]).bic(feature_data["fit_data"]) + for model in gmms + ] + + gmm_sklearn = mixture.GaussianMixture( + n_components=components[np.argmin(bics)], + covariance_type="full", + random_state=seed, + ) + gmm_torch = gmm.GaussianMixture(n_features=feature_data["fit_data"].size(dim=1)) + + gmm_sklearn.fit(feature_data["fit_data"].numpy()) + gmm_torch.fit(feature_data["fit_data"], rng=seed) + + assert torch.allclose(torch.from_numpy(gmm_sklearn.means_), gmm_torch.means) + assert torch.allclose( + torch.from_numpy(gmm_sklearn.covariances_), gmm_torch.covariances + ) + assert torch.allclose(torch.from_numpy(gmm_sklearn.weights_), gmm_torch.weights) + assert torch.allclose( + torch.from_numpy(gmm_sklearn.precisions_cholesky_), + gmm_torch.precisions_cholesky, + ) + + sklearn_nll = gmm_sklearn.score_samples(feature_data["test_data"].numpy()) + torch_nll = gmm_torch(feature_data["test_data"]) + + assert torch.allclose(-torch.from_numpy(sklearn_nll), torch_nll) + + # Test assertion error for covariance type other than "full" + def test_full_cov(self): + with pytest.raises(AssertionError) as excinfo: + _ = gmm.GaussianMixture(n_features=2, covariance_type="tied") + assert "covariance type was tied, should be full" in str(excinfo.value) + + # Test assertion error for evaluating unfitted GMM + def test_unfitted_gmm(self, gmm_torch, feature_data): + with pytest.raises(AssertionError) as excinfo: + _ = gmm_torch(feature_data["test_data"]) + assert "model has not been fitted" in str(excinfo.value)