Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/fix-pmg' into fix-pmg
Browse files (browse the repository at this point in the history)
  • Loading branch information
Han Wang committed Jan 17, 2025
2 parents 329ca1e + 09f97fc commit 90a489b
Show file tree
Hide file tree
Showing 24 changed files with 93 additions and 88 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-22.04
strategy:
matrix:
python-version: ["3.7", "3.8", "3.12"]
python-version: ["3.8", "3.12"]

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ repos:
# Python
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.8.4
rev: v0.9.1
hooks:
- id: ruff
args: ["--fix"]
Expand Down
2 changes: 1 addition & 1 deletion docs/installation.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Installation

DP-GEN only supports Python 3.7 and above. You can [setup a conda/pip environment](https://docs.deepmodeling.com/faq/conda.html), and then use one of the following methods to install DP-GEN:
dpdata only supports Python 3.8 and above. You can [set up a conda/pip environment](https://docs.deepmodeling.com/faq/conda.html), and then use one of the following methods to install dpdata:

- Install via pip: `pip install dpdata`
- Install via conda: `conda install -c conda-forge dpdata`
Expand Down
7 changes: 1 addition & 6 deletions docs/make_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@

import csv
import os
import sys
from collections import defaultdict
from inspect import Parameter, Signature, cleandoc, signature

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
from typing import Literal

from numpydoc.docscrape import Parameter as numpydoc_Parameter
from numpydoc.docscrape_sphinx import SphinxDocString
Expand Down
6 changes: 3 additions & 3 deletions dpdata/abacus/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def get_coords_from_dump(dumplines, natoms):
if "VIRIAL" in dumplines[6]:
calc_stress = True
check_line = 10
assert (
"POSITION" in dumplines[check_line]
), "keywords 'POSITION' cannot be found in the 6th line. Please check."
assert "POSITION" in dumplines[check_line], (
"keywords 'POSITION' cannot be found in the 6th line. Please check."
)
if "FORCE" in dumplines[check_line]:
calc_force = True

Expand Down
3 changes: 2 additions & 1 deletion dpdata/ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def calculate(
self.results["energy"] = data["energies"][0]
# see https://gitlab.com/ase/ase/-/merge_requests/2485
self.results["free_energy"] = data["energies"][0]
self.results["forces"] = data["forces"][0]
if "forces" in data:
self.results["forces"] = data["forces"][0]
if "virials" in data:
self.results["virial"] = data["virials"][0].reshape(3, 3)

Expand Down
6 changes: 3 additions & 3 deletions dpdata/deepmd/mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def to_system_data(folder, type_map=None, labels=True):
if type_map is not None:
assert isinstance(type_map, list)
missing_type = [i for i in old_type_map if i not in type_map]
assert (
not missing_type
), f"These types are missing in selected type_map: {missing_type} !"
assert not missing_type, (
f"These types are missing in selected type_map: {missing_type} !"
)
index_map = np.array([type_map.index(i) for i in old_type_map])
data["atom_names"] = type_map.copy()
else:
Expand Down
3 changes: 2 additions & 1 deletion dpdata/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ def label(self, data: dict) -> dict:
labeled_data = lb_data.copy()
else:
labeled_data["energies"] += lb_data["energies"]
labeled_data["forces"] += lb_data["forces"]
if "forces" in labeled_data and "forces" in lb_data:
labeled_data["forces"] += lb_data["forces"]
if "virials" in labeled_data and "virials" in lb_data:
labeled_data["virials"] += lb_data["virials"]
return labeled_data
Expand Down
12 changes: 9 additions & 3 deletions dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def to_labeled_system(self, data, *args, **kwargs) -> list[ase.Atoms]:
cell=data["cells"][ii],
)

results = {"energy": data["energies"][ii], "forces": data["forces"][ii]}
results = {"energy": data["energies"][ii]}
if "forces" in data:
results["forces"] = data["forces"][ii]
if "virials" in data:
# convert to GPa as this is ase convention
# v_pref = 1 * 1e4 / 1.602176621e6
Expand Down Expand Up @@ -296,7 +298,10 @@ def from_labeled_system(
dict_frames["energies"] = np.append(
dict_frames["energies"], tmp["energies"][0]
)
dict_frames["forces"] = np.append(dict_frames["forces"], tmp["forces"][0])
if "forces" in tmp.keys() and "forces" in dict_frames.keys():
dict_frames["forces"] = np.append(
dict_frames["forces"], tmp["forces"][0]
)
if "virials" in tmp.keys() and "virials" in dict_frames.keys():
dict_frames["virials"] = np.append(
dict_frames["virials"], tmp["virials"][0]
Expand All @@ -305,7 +310,8 @@ def from_labeled_system(
## Correct the shape of numpy arrays
dict_frames["cells"] = dict_frames["cells"].reshape(-1, 3, 3)
dict_frames["coords"] = dict_frames["coords"].reshape(len(sub_traj), -1, 3)
dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3)
if "forces" in dict_frames.keys():
dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3)
if "virials" in dict_frames.keys():
dict_frames["virials"] = dict_frames["virials"].reshape(-1, 3, 3)

Expand Down
18 changes: 9 additions & 9 deletions dpdata/plugins/n2p2.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,18 @@ def from_labeled_system(self, file_name: FileType, **kwargs):
energy = None
elif line.lower() == "end":
# If we are at the end of a section, process the section
assert (
len(coord) == len(atype) == len(force)
), "Number of atoms, atom types, and forces must match."
assert len(coord) == len(atype) == len(force), (
"Number of atoms, atom types, and forces must match."
)

# Check if the number of atoms is consistent across all frames
natom = len(coord)
if natom0 is None:
natom0 = natom
else:
assert (
natom == natom0
), "The number of atoms in all frames must be the same."
assert natom == natom0, (
"The number of atoms in all frames must be the same."
)

# Check if the number of atoms of each type is consistent across all frames
atype = np.array(atype)
Expand All @@ -108,9 +108,9 @@ def from_labeled_system(self, file_name: FileType, **kwargs):
if natoms0 is None:
natoms0 = natoms
else:
assert (
natoms == natoms0
), "The number of atoms of each type in all frames must be the same."
assert natoms == natoms0, (
"The number of atoms of each type in all frames must be the same."
)
if atom_types0 is None:
atom_types0 = atype
atom_order = match_indices(atom_types0, atype)
Expand Down
4 changes: 3 additions & 1 deletion dpdata/plugins/pwmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ def from_labeled_system(
data["cells"],
data["coords"],
data["energies"],
data["forces"],
tmp_force,
tmp_virial,
) = dpdata.pwmat.movement.get_frames(
file_name, begin=begin, step=step, convergence_check=convergence_check
)
if tmp_force is not None:
data["forces"] = tmp_force
if tmp_virial is not None:
data["virials"] = tmp_virial
# scale virial to the unit of eV
Expand Down
4 changes: 3 additions & 1 deletion dpdata/plugins/vasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def from_labeled_system(
data["cells"],
data["coords"],
data["energies"],
data["forces"],
tmp_force,
tmp_virial,
) = dpdata.vasp.outcar.get_frames(
file_name,
Expand All @@ -104,6 +104,8 @@ def from_labeled_system(
ml=ml,
convergence_check=convergence_check,
)
if tmp_force is not None:
data["forces"] = tmp_force
if tmp_virial is not None:
data["virials"] = tmp_virial
# scale virial to the unit of eV
Expand Down
20 changes: 10 additions & 10 deletions dpdata/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ class ErrorsBase(metaclass=ABCMeta):
SYSTEM_TYPE = object

def __init__(self, system_1: SYSTEM_TYPE, system_2: SYSTEM_TYPE) -> None:
assert isinstance(
system_1, self.SYSTEM_TYPE
), f"system_1 should be {self.SYSTEM_TYPE.__name__}"
assert isinstance(
system_2, self.SYSTEM_TYPE
), f"system_2 should be {self.SYSTEM_TYPE.__name__}"
assert isinstance(system_1, self.SYSTEM_TYPE), (
f"system_1 should be {self.SYSTEM_TYPE.__name__}"
)
assert isinstance(system_2, self.SYSTEM_TYPE), (
f"system_2 should be {self.SYSTEM_TYPE.__name__}"
)
self.system_1 = system_1
self.system_2 = system_2

Expand Down Expand Up @@ -116,15 +116,15 @@ class Errors(ErrorsBase):
SYSTEM_TYPE = LabeledSystem

@property
@lru_cache()
@lru_cache
def e_errors(self) -> np.ndarray:
"""Energy errors."""
assert isinstance(self.system_1, self.SYSTEM_TYPE)
assert isinstance(self.system_2, self.SYSTEM_TYPE)
return self.system_1["energies"] - self.system_2["energies"]

@property
@lru_cache()
@lru_cache
def f_errors(self) -> np.ndarray:
"""Force errors."""
assert isinstance(self.system_1, self.SYSTEM_TYPE)
Expand Down Expand Up @@ -153,7 +153,7 @@ class MultiErrors(ErrorsBase):
SYSTEM_TYPE = MultiSystems

@property
@lru_cache()
@lru_cache
def e_errors(self) -> np.ndarray:
"""Energy errors."""
assert isinstance(self.system_1, self.SYSTEM_TYPE)
Expand All @@ -166,7 +166,7 @@ def e_errors(self) -> np.ndarray:
return np.concatenate(errors)

@property
@lru_cache()
@lru_cache
def f_errors(self) -> np.ndarray:
"""Force errors."""
assert isinstance(self.system_1, self.SYSTEM_TYPE)
Expand Down
22 changes: 13 additions & 9 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,16 @@
import hashlib
import numbers
import os
import sys
import warnings
from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Literal,
overload,
)

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

import numpy as np

import dpdata
Expand Down Expand Up @@ -1209,7 +1204,11 @@ class LabeledSystem(System):
DTYPES: tuple[DataType, ...] = System.DTYPES + (
DataType("energies", np.ndarray, (Axis.NFRAMES,), deepmd_name="energy"),
DataType(
"forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3), deepmd_name="force"
"forces",
np.ndarray,
(Axis.NFRAMES, Axis.NATOMS, 3),
required=False,
deepmd_name="force",
),
DataType(
"virials",
Expand Down Expand Up @@ -1269,13 +1268,17 @@ def __add__(self, others):
raise RuntimeError("Unspported data structure")
return self.__class__.from_dict({"data": self_copy.data})

def has_forces(self) -> bool:
return "forces" in self.data

def has_virial(self) -> bool:
# return ('virials' in self.data) and (len(self.data['virials']) > 0)
return "virials" in self.data

def affine_map_fv(self, trans, f_idx: int | numbers.Integral):
assert np.linalg.det(trans) != 0
self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans)
if self.has_forces():
self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans)
if self.has_virial():
self.data["virials"][f_idx] = np.matmul(
trans.T, np.matmul(self.data["virials"][f_idx], trans)
Expand Down Expand Up @@ -1308,7 +1311,8 @@ def correction(self, hl_sys: LabeledSystem) -> LabeledSystem:
raise RuntimeError("high_sys should be LabeledSystem")
corrected_sys = self.copy()
corrected_sys.data["energies"] = hl_sys.data["energies"] - self.data["energies"]
corrected_sys.data["forces"] = hl_sys.data["forces"] - self.data["forces"]
if "forces" in self.data and "forces" in hl_sys.data:
corrected_sys.data["forces"] = hl_sys.data["forces"] - self.data["forces"]
if "virials" in self.data and "virials" in hl_sys.data:
corrected_sys.data["virials"] = (
hl_sys.data["virials"] - self.data["virials"]
Expand Down
7 changes: 1 addition & 6 deletions dpdata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@

import io
import os
import sys
from contextlib import contextmanager
from typing import TYPE_CHECKING, Generator, overload
from typing import TYPE_CHECKING, Generator, Literal, overload

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
import numpy as np

from dpdata.periodic_table import Element
Expand Down
6 changes: 3 additions & 3 deletions dpdata/vasp/outcar.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ def analyze_block(lines, ntot, nelm, ml=False):
not lines[idx + in_kB_index].split()[0:2] == ["in", "kB"]
):
in_kB_index += 1
assert idx + in_kB_index < len(
lines
), 'ERROR: "in kB" is not found in OUTCAR. Unable to extract virial.'
assert idx + in_kB_index < len(lines), (
'ERROR: "in kB" is not found in OUTCAR. Unable to extract virial.'
)
tmp_v = [float(ss) for ss in lines[idx + in_kB_index].split()[2:8]]
virial = np.zeros([3, 3])
virial[0][0] = tmp_v[0]
Expand Down
6 changes: 3 additions & 3 deletions dpdata/vasp/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@


def check_name(item, name):
assert (
item.attrib["name"] == name
), "item attrib '{}' dose not math required '{}'".format(item.attrib["name"], name)
assert item.attrib["name"] == name, (
"item attrib '{}' dose not math required '{}'".format(item.attrib["name"], name)
)


def get_varray(varray):
Expand Down
10 changes: 5 additions & 5 deletions dpdata/xyz/quip_gap_xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def handle_single_xyz_frame(lines):
list(filter(bool, field_dict["virial"].split(" ")))
).reshape(3, 3)
]
).astype("float32")
).astype(np.float64)
else:
virials = None

Expand All @@ -175,10 +175,10 @@ def handle_single_xyz_frame(lines):
3, 3
)
]
).astype("float32")
info_dict["coords"] = np.array([coords_array]).astype("float32")
info_dict["energies"] = np.array([field_dict["energy"]]).astype("float32")
info_dict["forces"] = np.array([force_array]).astype("float32")
).astype(np.float64)
info_dict["coords"] = np.array([coords_array]).astype(np.float64)
info_dict["energies"] = np.array([field_dict["energy"]]).astype(np.float64)
info_dict["forces"] = np.array([force_array]).astype(np.float64)
if virials is not None:
info_dict["virials"] = virials
info_dict["orig"] = np.zeros(3)
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ authors = [
]
license = {file = "LICENSE"}
classifiers = [
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
Expand All @@ -28,7 +27,7 @@ dependencies = [
'importlib_metadata>=1.4; python_version < "3.8"',
'typing_extensions; python_version < "3.8"',
]
requires-python = ">=3.7"
requires-python = ">=3.8"
readme = "README.md"
keywords = ["lammps", "vasp", "deepmd-kit"]

Expand Down
Loading

0 comments on commit 90a489b

Please sign in to comment.