Skip to content

Commit

Permalink
Update for reading pypolymlp MLPs from file
Browse files Browse the repository at this point in the history
  • Loading branch information
atztogo committed Sep 3, 2024
1 parent f38a704 commit b7fea1a
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 13 deletions.
7 changes: 6 additions & 1 deletion phono3py/api_phono3py.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
PypolymlpParams,
develop_polymlp,
evalulate_polymlp,
load_polymlp,
parse_mlp_params,
)
from phonopy.structure.atoms import PhonopyAtoms
Expand Down Expand Up @@ -2226,8 +2227,12 @@ def develop_mlp(
verbose=self._log_level - 1 > 0,
)

def load_mlp(self, filename: str = "pypolymlp.mlp"):
"""Load machine learning potential of pypolymlp."""
self._mlp = load_polymlp(filename=filename)

def evaluate_mlp(self):
"""Evaluate the machine learning potential of pypolymlp.
"""Evaluate machine learning potential of pypolymlp.
This method calculates the supercell energies and forces from the MLP
for the displacements in self._dataset of type 2. The results are stored
Expand Down
21 changes: 14 additions & 7 deletions phono3py/cui/create_force_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,8 @@ def _read_dataset_fc3(
file_exists(e.filename, log_level=log_level)

if use_pypolymlp:
phono3py.mlp_dataset = dataset
if forces_in_dataset(dataset):
phono3py.mlp_dataset = dataset
run_pypolymlp_to_compute_forces(
phono3py,
mlp_params,
Expand All @@ -521,6 +522,7 @@ def run_pypolymlp_to_compute_forces(
displacement_distance: Optional[float] = None,
number_of_snapshots: Optional[int] = None,
random_seed: Optional[int] = None,
mlp_filename: str = "pypolymlp.mlp",
log_level: int = 0,
):
"""Run pypolymlp to compute forces."""
Expand All @@ -536,10 +538,18 @@ def run_pypolymlp_to_compute_forces(
print(f" {k}: {v}")
if log_level > 1:
print("")
if log_level:
print("Developing MLPs by pypolymlp...", flush=True)

ph3py.develop_mlp(params=mlp_params)
if forces_in_dataset(ph3py.mlp_dataset):
if log_level:
print("Developing MLPs by pypolymlp...", flush=True)
ph3py.develop_mlp(params=mlp_params)
else:
if pathlib.Path(mlp_filename).exists():
if log_level:
print(f'Load MLPs from "{mlp_filename}".')
ph3py.load_mlp(mlp_filename)
else:
raise RuntimeError(f'"{mlp_filename}" is not found.')

if log_level:
print("-" * 30 + " pypolymlp end " + "-" * 31, flush=True)
Expand Down Expand Up @@ -577,9 +587,6 @@ def run_pypolymlp_to_compute_forces(
flush=True,
)

if ph3py.mlp_dataset is None:
msg = "mlp_dataset has to be set before calling this method."
raise RuntimeError(msg)
if ph3py.supercells_with_displacements is None:
raise RuntimeError("Displacements are not set. Run generate_displacements.")

Expand Down
7 changes: 4 additions & 3 deletions phono3py/cui/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,8 @@ def set_dataset_and_force_constants(
)
if not read_fc["fc3"]:
if use_pypolymlp:
ph3py.mlp_dataset = dataset
if forces_in_dataset(dataset):
ph3py.mlp_dataset = dataset
else:
ph3py.dataset = dataset
read_fc["fc2"], phonon_dataset = _get_dataset_phonon_dataset_or_fc2(
Expand Down Expand Up @@ -461,8 +462,8 @@ def compute_force_constants_from_datasets(
"""
fc3_calculator = extract_fc2_fc3_calculators(fc_calculator, 3)
fc2_calculator = extract_fc2_fc3_calculators(fc_calculator, 2)
if not read_fc["fc3"] and (ph3py.dataset or ph3py.mlp_dataset):
if use_pypolymlp and forces_in_dataset(ph3py.mlp_dataset):
if not read_fc["fc3"]:
if use_pypolymlp:
run_pypolymlp_to_compute_forces(
ph3py,
mlp_params=mlp_params,
Expand Down
2 changes: 1 addition & 1 deletion phono3py/phonon3/displacement_fc3.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def get_bond_symmetry(
def get_least_orbits(atom_index, cell, site_symmetry, symprec=1e-5):
"""Find least orbits for a centering atom."""
orbits = _get_orbits(atom_index, cell, site_symmetry, symprec)
mapping = np.arange(cell.get_number_of_atoms())
mapping = np.arange(len(cell))

for i, orb in enumerate(orbits):
for num in np.unique(orb):
Expand Down
2 changes: 1 addition & 1 deletion phono3py/phonon3/reciprocal_to_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _reciprocal_to_normal(self, grid_triplet):
self._fc3_normal[i, j, k] = fc3_elem / fff

def _sum_in_atoms(self, band_indices, eigvecs):
num_atom = self._primitive.get_number_of_atoms()
num_atom = len(self._primitive)
(e1, e2, e3) = eigvecs
(b1, b2, b3) = band_indices

Expand Down

0 comments on commit b7fea1a

Please sign in to comment.