diff --git a/phono3py/api_phono3py.py b/phono3py/api_phono3py.py index 00d54efe..41f971cb 100644 --- a/phono3py/api_phono3py.py +++ b/phono3py/api_phono3py.py @@ -60,6 +60,7 @@ PypolymlpParams, develop_polymlp, evalulate_polymlp, + load_polymlp, parse_mlp_params, ) from phonopy.structure.atoms import PhonopyAtoms @@ -2226,8 +2227,12 @@ def develop_mlp( verbose=self._log_level - 1 > 0, ) + def load_mlp(self, filename: str = "pypolymlp.mlp"): + """Load machine learning potential of pypolymlp.""" + self._mlp = load_polymlp(filename=filename) + def evaluate_mlp(self): - """Evaluate the machine learning potential of pypolymlp. + """Evaluate machine learning potential of pypolymlp. This method calculates the supercell energies and forces from the MLP for the displacements in self._dataset of type 2. The results are stored diff --git a/phono3py/cui/create_force_constants.py b/phono3py/cui/create_force_constants.py index 7d33ec6c..99eb892c 100644 --- a/phono3py/cui/create_force_constants.py +++ b/phono3py/cui/create_force_constants.py @@ -502,7 +502,8 @@ def _read_dataset_fc3( file_exists(e.filename, log_level=log_level) if use_pypolymlp: - phono3py.mlp_dataset = dataset + if forces_in_dataset(dataset): + phono3py.mlp_dataset = dataset run_pypolymlp_to_compute_forces( phono3py, mlp_params, @@ -521,6 +522,7 @@ def run_pypolymlp_to_compute_forces( displacement_distance: Optional[float] = None, number_of_snapshots: Optional[int] = None, random_seed: Optional[int] = None, + mlp_filename: str = "pypolymlp.mlp", log_level: int = 0, ): """Run pypolymlp to compute forces.""" @@ -536,10 +538,18 @@ def run_pypolymlp_to_compute_forces( print(f" {k}: {v}") if log_level > 1: print("") - if log_level: - print("Developing MLPs by pypolymlp...", flush=True) - ph3py.develop_mlp(params=mlp_params) + if forces_in_dataset(ph3py.mlp_dataset): + if log_level: + print("Developing MLPs by pypolymlp...", flush=True) + ph3py.develop_mlp(params=mlp_params) + else: + if pathlib.Path(mlp_filename).exists(): + if log_level: + print(f'Load MLPs from "{mlp_filename}".') + ph3py.load_mlp(mlp_filename) + else: + raise RuntimeError(f'"{mlp_filename}" is not found.') if log_level: print("-" * 30 + " pypolymlp end " + "-" * 31, flush=True) @@ -577,9 +587,6 @@ def run_pypolymlp_to_compute_forces( flush=True, ) - if ph3py.mlp_dataset is None: - msg = "mlp_dataset has to be set before calling this method." - raise RuntimeError(msg) if ph3py.supercells_with_displacements is None: raise RuntimeError("Displacements are not set. Run generate_displacements.") diff --git a/phono3py/cui/load.py b/phono3py/cui/load.py index f87d6e2c..e4dda9d0 100644 --- a/phono3py/cui/load.py +++ b/phono3py/cui/load.py @@ -416,7 +416,8 @@ def set_dataset_and_force_constants( ) if not read_fc["fc3"]: if use_pypolymlp: - ph3py.mlp_dataset = dataset + if forces_in_dataset(dataset): + ph3py.mlp_dataset = dataset else: ph3py.dataset = dataset read_fc["fc2"], phonon_dataset = _get_dataset_phonon_dataset_or_fc2( @@ -461,8 +462,8 @@ def compute_force_constants_from_datasets( """ fc3_calculator = extract_fc2_fc3_calculators(fc_calculator, 3) fc2_calculator = extract_fc2_fc3_calculators(fc_calculator, 2) - if not read_fc["fc3"] and (ph3py.dataset or ph3py.mlp_dataset): - if use_pypolymlp and forces_in_dataset(ph3py.mlp_dataset): + if not read_fc["fc3"]: + if use_pypolymlp: run_pypolymlp_to_compute_forces( ph3py, mlp_params=mlp_params, diff --git a/phono3py/phonon3/displacement_fc3.py b/phono3py/phonon3/displacement_fc3.py index f769c900..9afd6eac 100644 --- a/phono3py/phonon3/displacement_fc3.py +++ b/phono3py/phonon3/displacement_fc3.py @@ -272,7 +272,7 @@ def get_bond_symmetry( def get_least_orbits(atom_index, cell, site_symmetry, symprec=1e-5): """Find least orbits for a centering atom.""" orbits = _get_orbits(atom_index, cell, site_symmetry, symprec) - mapping = np.arange(cell.get_number_of_atoms()) + mapping = np.arange(len(cell)) for i, orb in enumerate(orbits): for num in np.unique(orb): diff --git a/phono3py/phonon3/reciprocal_to_normal.py b/phono3py/phonon3/reciprocal_to_normal.py index 9cf2970e..4db52090 100644 --- a/phono3py/phonon3/reciprocal_to_normal.py +++ b/phono3py/phonon3/reciprocal_to_normal.py @@ -88,7 +88,7 @@ def _reciprocal_to_normal(self, grid_triplet): self._fc3_normal[i, j, k] = fc3_elem / fff def _sum_in_atoms(self, band_indices, eigvecs): - num_atom = self._primitive.get_number_of_atoms() + num_atom = len(self._primitive) (e1, e2, e3) = eigvecs (b1, b2, b3) = band_indices