Skip to content

Commit

Permalink
enable ff_opt option to allow training data collection
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Jan 27, 2025
1 parent bec3a83 commit 49a1dd2
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 247 deletions.
2 changes: 1 addition & 1 deletion pyxtal/interface/ase_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def handler(signum, frame):

tag = 'False' if struc is None else 'True'
logger.info(f"Finishing {label} {tag}")
#signal.alarm(0) # Cancel the alarm
#signal.alarm(0) # Cancel the alarm
return struc #, eng, _fmax

class ASE_optimizer:
Expand Down
3 changes: 1 addition & 2 deletions pyxtal/optimize/DFS.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ def _run(self, pool=None):
success_rate or None
"""
# Related to the FF optimization
N_added = 0
success_rate = 0
cur_survivals = [0] * self.N_pop # track the survivals
hist_best_xtals = [None] * self.N_pop
Expand Down Expand Up @@ -242,7 +241,7 @@ def _run(self, pool=None):

# Update the FF parameters if necessary
if self.ff_opt:
N_added = self.update_ff_paramters(cur_xtals, engs, N_added)
self.export_references(cur_xtals, engs)
else:
quit = False
if self.rank == 0:
Expand Down
1 change: 0 additions & 1 deletion pyxtal/optimize/QRS.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ def _run(self, pool=None):
success_rate or None
"""
self.ref_volumes = []
N_added = 0
success_rate = 0
print(f"Rank {self.rank} starts QRS in {self.tag}")

Expand Down
3 changes: 1 addition & 2 deletions pyxtal/optimize/WFS.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def _run(self, pool=None):
"""

# Related to the FF optimization
N_added = 0
success_rate = 0
print(f"Rank {self.rank} starts WFS in {self.tag}")

Expand Down Expand Up @@ -219,7 +218,7 @@ def _run(self, pool=None):

# Update the FF parameters if necessary
if self.ff_opt:
N_added = self.update_ff_paramters(cur_xtals, engs, N_added)
self.export_references(cur_xtals, engs)
else:
quit = False
if self.rank == 0:
Expand Down
286 changes: 45 additions & 241 deletions pyxtal/optimize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def __init__(
if self.rank == 0:
from pyocse.parameters import ForceFieldParameters
self.parameters = ForceFieldParameters(
self.smiles,
style=ff_style,
self.smiles,
style=ff_style,
ncpu=self.ncpu)
if self.ff_opt:
self.parameters.set_ref_evaluator('mace')
Expand Down Expand Up @@ -333,7 +333,7 @@ def __str__(self):
s += f"Mode : Sampling\n"
s += f"cif : {self.cif:s}\n"
if self.ff_opt:
s += "forcefield: On-the-fly\n"
s += "forcefield: Sample-training\n"
else:
s += "forcefield: Predefined\n"

Expand Down Expand Up @@ -478,265 +478,69 @@ def early_termination(self, success_rate):
return True
return False

def update_ff_paramters(self, xtals, engs, N_added):
def export_references(self, xtals, engs, N_min=50, dE=2.5, FMSE=2.5):
"""
Update the ff parameters
Add training data
Args:
xtals: a list of pyxtals
engs: a list of energies
N_added (int): the number of structures that have been added
N_min (int): minimum number of configs to add
dE (float): the cutoff energy value
FMSE (float): the cutoff Force MSE value
"""

gen = self.generation
N_max = min([int(self.N_pop * 0.6), 50])
ids = np.argsort(engs)
_xtals = self.select_xtals(xtals, ids, N_max)
print("Select structures for FF optimization", len(_xtals))

return self.ff_optimization(_xtals, N_added)

def ff_optimization(self, xtals, N_added, N_min=50, dE=2.5, FMSE=2.5):
"""
Optimize the current FF based on newly explored data
Args:
xtals: list of xtals from the current gen
N_added (int): the number of structures that have been added
N_min (int): minimum number of configurations to trigger the FF training
dE (float): the cutoff energy value
FMSE (float): the cutoff Force MSE value
"""
from pyocse.utils import reset_lammps_cell
from pyocse.parameters import compute_r2, get_lmp_efs

numMols = [xtal.numMols for xtal in xtals]
xtals = [xtal.to_ase(resort=False) for xtal in xtals]
numMols = [xtal.numMols for xtal in _xtals]
_xtals = [xtal.to_ase(resort=False) for xtal in _xtals]

gen = self.generation
# Initialize ff parameters and references
params_opt, err_dict = self.parameters.load_parameters(
self.ff_parameters)
# Initialize references
if os.path.exists(self.reference_file):
ref_dics = self.parameters.load_references(self.reference_file)
ref_dics = self.parameters.cut_references_by_error(
ref_dics, params_opt, dE=dE, FMSE=FMSE)
if self.ref_criteria is not None:
ref_dics = self.parameters.clean_ref_dics(
ref_dics, self.ref_criteria)
ref_dics = self.parameters.cut_references_by_error(
ref_dics, params_opt, dE=dE, FMSE=FMSE)
# self.parameters.generate_report(ref_dics, params_opt)
ref_dics, self.ref_criteria)
else:
ref_dics = []

# Add references
print("Current number of reference structures", len(ref_dics))
t0 = time()
if len(ref_dics) > 100: # no fit if ref_dics is large
# Here we find the lowest engs and select only low-E struc
ref_engs = [ref_dic["energy"] / ref_dic["replicate"]
for ref_dic in ref_dics]
ref_e2 = np.array(ref_engs).min()
print("Min Reference Energy", ref_e2)

if len(err_dict) == 0:
# update the offset if necessary
_, params_opt = self.parameters.optimize_offset(
ref_dics, params_opt)
results = self.parameters.evaluate_multi_references(
ref_dics, params_opt, 1000, 1000)
(ff_values, ref_values, rmse_values, r2_values) = results
err_dict = {"rmse_values": rmse_values}

_ref_dics = []
rmse_values = err_dict["rmse_values"]
lmp_in = self.parameters.ff.get_lammps_in()
self.parameters.ase_templates = {}
self.lmp_dat = {}
ff_engs, ff_fors, ff_strs = [], [], []
rf_engs, rf_fors, rf_strs = [], [], []

for numMol, xtal in zip(numMols, xtals):
struc = reset_lammps_cell(xtal)
lmp_struc, lmp_dat = self.parameters.get_lmp_input_from_structure(
struc, numMol, set_template=False)
replicate = len(lmp_struc.atoms) / \
self.parameters.natoms_per_unit

try:
# ; print('Debug KONTIQ', struc, e1)
e1, f1, s1 = get_lmp_efs(lmp_struc, lmp_in, lmp_dat)
except:
e1 = self.E_max

# filter very high energy structures
if e1 < 1000:
struc.set_calculator(self.parameters.calculator)
e2 = struc.get_potential_energy()
e2 /= replicate
# Ignore very high energy structures
if e2 < ref_e2 + self.eng_cutoff:
e1 /= replicate
f2 = struc.get_forces()
s2 = struc.get_stress()
struc.set_calculator()
e_err = abs(e1 - e2 + params_opt[-1])
f_err = np.sqrt(
((f1.flatten() - f2.flatten()) ** 2).mean())
s_err = np.sqrt(((s1 - s2) ** 2).mean())
e_check = e_err < 0.5 * rmse_values[0]
f_check = f_err < 1.0 * rmse_values[1]
s_check = s_err < 1.0 * rmse_values[2]
strs = f"Errors of csp structure in gen{gen:3d} "
strs += f"{e_err:.4f} {f_err:.4f} {s_err:.4f} "
strs += f"{e1:8.4f} {e2:8.4f}"
print(strs, e_check, f_check, s_check)

# avoid very unphysical structures
if e_err < 4.0 and f_err < 4.0:
ff_engs.append(e1 + params_opt[-1])
rf_engs.append(e2)
ff_fors.extend(f1.flatten())
rf_fors.extend(f2.flatten())
ff_strs.extend(s1.flatten())
rf_strs.extend(s2.flatten())

if False in [e_check, f_check, s_check]:
_ref_dic = {
"structure": struc,
"energy": e2 * replicate,
"forces": f2,
"stress": s2,
"replicate": replicate,
"options": [True, not f_check, True],
"tag": "CSP",
"numMols": numMol,
}
_ref_dics.append(_ref_dic)
else:
print("Ignore the structure due to high energy", e2)

# QZ: Output FF performances MSE, R2 for the selected structures
if len(_ref_dics) == 0:
print("There is a serious problem in depositing high energy")
raise ValueError("The program needs to stop here")

if self.ref_criteria is not None:
_ref_dics = self.parameters.clean_ref_dics(
_ref_dics, self.ref_criteria)

# print("Added {:d} new reference structures into training".format(len(_ref_dics)))
print("FF performances")
ff_engs = np.array(ff_engs)
rf_engs = np.array(rf_engs)
ff_fors = np.array(ff_fors)
rf_fors = np.array(rf_fors)
ff_strs = np.array(ff_strs)
rf_strs = np.array(rf_strs)
r2_engs = compute_r2(ff_engs, rf_engs)
r2_fors = compute_r2(ff_fors, rf_fors)
r2_strs = compute_r2(ff_strs, rf_strs)
mse_engs = np.sqrt(np.mean((ff_engs - rf_engs) ** 2))
mse_fors = np.sqrt(np.mean((ff_fors - rf_fors) ** 2))
mse_strs = np.sqrt(np.mean((ff_strs - rf_strs) ** 2))

print(f"R2 {r2_engs:8.4f} {r2_fors:8.4f} {r2_strs:8.4f}")
print(f"RMSE {mse_engs:8.4f} {mse_fors:8.4f} {mse_strs:8.4f}")
# self.parameters.generate_report(_ref_dics, params_opt)
# import sys; sys.exit()
else:
# reduce the number of structures to save some time
N_selected = min([N_min, self.ncpu])
print("Create the reference data by augmentation", N_selected)
if len(xtals) >= N_selected:
ids = self.random_state.choice(
list(range(len(xtals))), N_selected)
xtals = [xtals[id] for id in ids]
numMols = [numMols[id] for id in ids]

_ref_dics = self.parameters.add_multi_references(
xtals,
numMols,
augment=True,
steps=20, # 50,
N_vibs=1,
logfile="ase.log",
)
if len(ref_dics) == 0:
_, params_opt = self.parameters.optimize_offset(
_ref_dics, params_opt)
self.parameters.generate_report(ref_dics, params_opt)
# import sys; sys.exit()

N_selected = min([N_min, self.ncpu])
print("Current number of reference structures", len(ref_dics))
print("Create the reference data by augmentation", N_selected)
if len(_xtals) >= N_selected:
ids = self.random_state.choice(list(range(len(_xtals))), N_selected)
_xtals = [_xtals[id] for id in ids]
numMols = [numMols[id] for id in ids]

_ref_dics = self.parameters.add_multi_references(_xtals,
numMols,
augment=True,
steps=20, #50,
N_vibs=1,
logfile="ase.log")
ref_dics.extend(_ref_dics)
print(
f"Add {len(_ref_dics):d} references in {(time() - t0) / 60:.2f} min")
self.parameters.export_references(ref_dics, self.reference_file)

# Optimize ff parameters if we get enough number of configurations
N_added += len(_ref_dics)
if N_added < N_min:
print("Do not update ff, the current number of configurations is", N_added)
else:
t0 = time()
_, params_opt = self.parameters.optimize_offset(
ref_dics, params_opt)

for data in [
(["bond", "angle", "proper"], 50),
(["proper", "vdW", "charge"], 50),
(["bond", "angle", "proper", "vdW", "charge"], 50),
]:
(terms, steps) = data

# Actual optimization
opt_dict = self.parameters.get_opt_dict(
terms, None, params_opt)
x, fun, values, it = self.parameters.optimize_global(
ref_dics, opt_dict, params_opt, obj="R2", t0=0.1, steps=25
)

params_opt = self.parameters.set_sub_parameters(
values, terms, params_opt)

opt_dict = self.parameters.get_opt_dict(
terms, None, params_opt)

x, fun, values, it = self.parameters.optimize_local(
ref_dics, opt_dict, params_opt, obj="R2", steps=steps)

params_opt = self.parameters.set_sub_parameters(
values, terms, params_opt)
_, params_opt = self.parameters.optimize_offset(
ref_dics, params_opt)

# To add Early termination

t = (time() - t0) / 60
print(f"FF optimization {t:.2f} min ", fun)
# Reset N_added to 0
N_added = 0

# Export FF performances
if gen < 10:
gen_prefix = "gen_00" + str(gen)
elif gen < 100:
gen_prefix = "gen_0" + str(gen)
else:
gen_prefix = "gen_" + str(gen)
t1 = (time() - t0) / 60
print(f"Add {len(_ref_dics)} references in {t1:.2f} min")

performance_fig = f"{self.workdir:s}/FF_performance_{gen_prefix:s}.png"
errs = self.parameters.plot_ff_results(performance_fig, ref_dics, [
params_opt], labels=gen_prefix)

param_fig = f"{self.workdir:s}/parameters_{gen_prefix:s}.png"
self.parameters.plot_ff_parameters(param_fig, [params_opt])

# Save parameters
self.parameters.export_parameters(
self.ff_parameters, params_opt, errs[0])
self._prepare_chm_info(params_opt)
# self.parameters.generate_report(ref_dics, params_opt)
# Export FF performances and references
# TODO: export the references in an appending way
self.parameters.export_references(ref_dics, self.reference_file)
gen_prefix = self.get_label(gen, 'gen_')
performance_fig = f"{self.workdir}/FF_performance_{gen_prefix}.png"
params, _ = self.parameters.load_parameters(self.ff_parameters)
self.parameters.plot_ff_results(performance_fig,
ref_dics,
[params],
labels=gen_prefix)

return N_added

def _prepare_chm_info(self, params0, params1=None, folder="calc", suffix="pyxtal0"):
"""
Expand Down Expand Up @@ -773,13 +577,13 @@ def _prepare_chm_info(self, params0, params1=None, folder="calc", suffix="pyxtal
# Info
return ase_with_ff.get_atom_info()

def get_label(self, i, label='cpu'):
    """
    Build a label made of a prefix and a zero-padded index.

    Args:
        i (int): index to encode (e.g., a cpu rank or a generation number)
        label (str): prefix placed in front of the padded index

    Returns:
        str: `label` followed by `i` padded to at least three digits,
        e.g. `cpu007`, `cpu042`, `gen_123`
    """
    # Pad uniformly to three digits. The earlier if/elif/else chain reused
    # the two-digit rule for i >= 100 and produced e.g. `cpu0100`.
    return f"{label}{str(i).zfill(3)}"

def print_matches(self, header=None):
Expand Down

0 comments on commit 49a1dd2

Please sign in to comment.