# solver_general.py — solves candidate equations (tedeous / odeint) and stores the results on disk.
import numpy as np
import torch
import sys
import os
import time
from epde.integrate import OdeintAdapter
from tedeous.data import Equation
from tedeous.model import Model
from tedeous.callbacks import early_stopping, plot, cache
from tedeous.optimizers.optimizer import Optimizer
from tedeous.device import solver_device, check_device
from tedeous.models import mat_model
from epde.interface.interface import EpdeSearch
from epde.interface.equation_translator import translate_equation
from func.transition_bs import text_form_of_equation, solver_form_to_text_form
from func import transition_bs as transform
import tkinter as tk
from tkinter import filedialog, messagebox
# Module-level environment setup.
# NOTE(review): both statements run after the imports above, so they cannot
# influence how those imports are resolved or initialised — confirm this
# ordering is intended.
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'TRUE'})
sys.path.extend(('../',))
def load_result(title):
    """Let the user pick a saved solver result and load it with ``torch.load``.

    A Tk root window is created and immediately hidden so that only the file
    dialog is shown. The dialog opens in ``data/<title>/solver_result``.
    If anything fails while selecting or loading the file, a retry/cancel
    message box is shown and the dialog is reopened when the user retries.

    Args:
        title: Name of the folder under ``data/`` whose ``solver_result``
            directory is used as the dialog's initial directory.

    Returns:
        The object stored in the selected file, or ``None`` when the user
        cancels the dialog or declines to retry after a failure.
    """
    root = tk.Tk()
    root.withdraw()
    solver_result_dir = f'data/{title}/solver_result'
    try:
        while True:
            try:
                file_path = filedialog.askopenfilename(
                    initialdir=solver_result_dir,
                    title="Select file",
                    filetypes=(("PyTorch files", "*.pt"), ("all files", "*.*"))
                )
                root.withdraw()
                if not file_path:
                    print("File selection was cancelled.")
                    return None
                # SECURITY NOTE: torch.load unpickles the file, so only files
                # from a trusted source should be selected here.
                set_solutions = torch.load(file_path)
                print(f"The file '{file_path}' has been successfully uploaded.")
                return set_solutions
            except Exception as e:
                # UI boundary: report the problem and let the user decide
                # whether to pick another file.
                print(f"Error during file upload: {e}")
                retry = messagebox.askretrycancel("File Load Error", f"Failed to load file: {e}\nRetry?")
                if not retry:
                    return None
    finally:
        # Destroy the hidden root on every exit path (success, cancel,
        # declined retry, unexpected error) so no Tk window is leaked.
        root.destroy()
def solver_equations(cfg, domain, params_full, b_conds, equations, epde_obj: "EpdeSearch" = False, title=None):
    """Solve every equation in ``equations`` and store the stacked solutions.

    Depending on ``cfg.params["glob_solver"]["type"]`` each equation is solved
    either by training a tedeous ``Model`` (``mat_model`` for mode ``'mat'``,
    otherwise a fully connected ``torch.nn.Sequential`` network) or by
    ``OdeintAdapter`` (type ``'odeint'``). After every accepted solution the
    current stack is saved under ``data/<title>/solver_result``; the final
    stack is saved once more after the loop, with the number of files in the
    directory appended to the name if that file name is already taken.

    Args:
        cfg: Configuration object exposing a nested ``params`` mapping.
        domain: Passed to ``mat_model`` / ``Model`` and used to build the grid.
        params_full: Sequence of per-axis grids; the length of each entry
            defines the shape solutions are reshaped to when ``dim > 1``.
            It is also handed to the odeint solver as ``grids``.
        b_conds: Boundary conditions passed to ``Model``.
        equations: Iterable of equations; each is converted with
            ``transform.solver_view``.
        epde_obj: Object whose ``pool`` attribute is used to translate
            equations; only needed for the ``'odeint'`` solver type.
        title: Name of the folder under ``data/`` used for the results.

    Returns:
        ``numpy.ndarray`` of all accepted solutions, or the value returned by
        ``load_result`` when ``cfg.params["glob_solver"]["load_result"]`` is
        set.

    Raises:
        ValueError: If the ``'odeint'`` solver type is used without
            ``epde_obj``.
    """
    # solver_device('cuda')
    torch.set_default_dtype(torch.float32)

    result_dir = f'data/{title}/solver_result'
    # makedirs also creates the missing parents (data/, data/<title>) and does
    # not fail when the directory already exists.
    os.makedirs(result_dir, exist_ok=True)

    solver_cfg = cfg.params["glob_solver"]
    if solver_cfg["load_result"]:
        return load_result(title)

    dim = cfg.params["global_config"]["dimensionality"] + 1  # (starts from 0 - [t,], 1 - [t, x], 2 - [t, x, y])
    variable_names = cfg.params["fit"]["variable_names"]
    k_variable_names = len(variable_names)
    mode = solver_cfg["mode"]
    variance_arr = cfg.params["global_config"]["variance_arr"]
    grid_shape = [len(i) for i in params_full]

    def _result_path(shape, suffix=''):
        # Single place that builds the result file name.
        return f'{result_dir}/file_u_main_{list(shape)}_{mode}_{variance_arr}{suffix}.pt'

    set_solutions = []
    errors = []
    for number, equation_i in enumerate(equations):
        start = time.time()
        eq_solver = transform.solver_view(equation_i, cfg)
        equation = Equation()
        if k_variable_names > 1:  # if the system, when we get the list from transform.solver_view
            for eq_i in eq_solver:
                equation.add(eq_i)
        else:
            equation.add(eq_solver)

        if solver_cfg["type"] != 'odeint':
            if mode == 'mat':
                net = mat_model(domain, equation)
            else:  # for variant mode = "NN" and "autograd"
                net = torch.nn.Sequential(
                    torch.nn.Linear(dim, 100),
                    torch.nn.Tanh(),
                    torch.nn.Linear(100, 100),
                    torch.nn.Tanh(),
                    torch.nn.Linear(100, 100),
                    torch.nn.Tanh(),
                    torch.nn.Linear(100, k_variable_names)
                )
            model = Model(net, domain, equation, b_conds)
            model.compile(mode=mode,
                          lambda_operator=cfg.params['Optimizer']['lambda_operator'],
                          lambda_bound=cfg.params['Optimizer']['lambda_bound'])
            cb_es = early_stopping.EarlyStopping(eps=cfg.params['StopCriterion']['eps'],
                                                 no_improvement_patience=cfg.params['StopCriterion']['no_improvement_patience'],
                                                 patience=cfg.params['StopCriterion']['patience'],
                                                 verbose=cfg.params['StopCriterion']['verbose'],
                                                 info_string_every=cfg.params['StopCriterion']['print_every'])
            cb_cache = cache.Cache(cache_dir=cfg.params['Cache']['cache_dir'],
                                   cache_verbose=cfg.params['Cache']['cache_verbose'],
                                   model_randomize_parameter=cfg.params['Cache']['model_randomize_parameter'])
            cb_plots = plot.Plots(save_every=cfg.params["Plot"]["step_plot_save"],
                                  print_every=cfg.params["Plot"]["step_plot_print"],
                                  img_dir=cfg.params["Plot"]["image_save_dir"])
            optimizer = Optimizer(optimizer=cfg.params['Optimizer']['optimizer'],
                                  params={'lr': cfg.params['Optimizer']['learning_rate']})
            model.train(optimizer, epochs=cfg.params['Optimizer']['epochs'],
                        save_model=cfg.params['Cache']['save_always'],
                        callbacks=[cb_es, cb_plots, cb_cache])
            end = time.time()
            print(f'Time = {end - start}')
            grid = domain.build(mode)
            # In 'mat' mode the trained model itself is used as the solution;
            # otherwise the network is evaluated on the grid.
            solution_function = net if mode == "mat" else net(grid)
            if dim > 1:
                solution_function = solution_function.reshape(*grid_shape)
            solution_function = solution_function.detach().cpu().numpy()
        else:
            if epde_obj is False or epde_obj is None:
                raise ValueError("epde_obj is required when glob_solver type is 'odeint'")
            # NOTE(review): eq_solver is indexed per variable here; presumably
            # solver_view returns a sequence in this configuration — confirm
            # for the single-variable case.
            dict_odeint = {var: solver_form_to_text_form(eq_solver[i], cfg)
                           for i, var in enumerate(variable_names)}
            eq_translated = translate_equation(dict_odeint, epde_obj.pool, variable_names)
            model = OdeintAdapter()  # method='LSODA', 'Radau'
            try:
                _, solution_function = model.solve_epde_system(system=eq_translated, grids=params_full, mode='autograd')
            except Exception as e:
                # Best effort: remember the failure and go on with the next equation.
                errors.append((number, equation_i, str(e)))
                print(f"Error: {e}. The equation {equation_i} is unsolvable...")
                continue
            if dim > 1:
                solution_function = solution_function.reshape(*grid_shape)

        # NOTE(review): indentation was lost in the reviewed copy of this file;
        # this length check is assumed to apply to both solver paths — confirm.
        if solution_function.shape[0] != grid_shape[0]:
            print(solution_function.shape)
            print(number, equation_i)
            print('--------------------------')
            errors.append((number, equation_i, f'unexpected solution shape {solution_function.shape}'))
            continue

        set_solutions.append(solution_function)
        # To save temporary solutions
        stacked = np.array(set_solutions)
        torch.save(stacked, _result_path(stacked.shape))

    if errors:
        print(f'{len(errors)} equation(s) were skipped: {[failed[0] for failed in errors]}')

    set_solutions = np.array(set_solutions)
    number_of_files = len(os.listdir(path=f"{result_dir}/"))
    final_path = _result_path(set_solutions.shape)
    if os.path.exists(final_path):
        # Do not overwrite an existing file with the same name.
        final_path = _result_path(set_solutions.shape, suffix=f'_{number_of_files}')
    torch.save(set_solutions, final_path)
    return set_solutions