From 42bc6247ba53e1054b72e657230e93c8f0f920ee Mon Sep 17 00:00:00 2001
From: SuperSashka
Date: Wed, 4 Sep 2024 11:26:35 +0300
Subject: [PATCH] More stable NGD

---
 examples/example_Lotka_Volterra_paper.py     |  36 +-
 examples/example_Lotka_Volterra_paper_NGD.py | 330 +++++++++++++++++++
 tedeous/callbacks/early_stopping.py          |  14 +-
 tedeous/model.py                             |   6 +-
 tedeous/optimizers/ngd.py                    | 102 +++++-
 5 files changed, 458 insertions(+), 30 deletions(-)
 create mode 100644 examples/example_Lotka_Volterra_paper_NGD.py

diff --git a/examples/example_Lotka_Volterra_paper.py b/examples/example_Lotka_Volterra_paper.py
index c782f7ee..12ed954e 100644
--- a/examples/example_Lotka_Volterra_paper.py
+++ b/examples/example_Lotka_Volterra_paper.py
@@ -24,7 +24,7 @@
 from tedeous.callbacks import cache, early_stopping, plot
 from tedeous.optimizers.optimizer import Optimizer
 from tedeous.device import solver_device, check_device, device_type
-
+from tedeous.models import Fourier_embedding
 
 alpha = 20.
 beta = 20.
@@ -47,22 +47,18 @@ def __init__(self):
         self.width_out=[2]
 
         # Shared layers (base network)
-        self.shared_fc1 = torch.nn.Linear(1, 100)  # Input size of 1 (for t)
-        self.shared_fc2 = torch.nn.Linear(100, 100)
-        self.shared_fc3 = torch.nn.Linear(100, 100)
-        self.shared_fc4 = torch.nn.Linear(100, 100)
+        self.shared_fc1 = torch.nn.Linear(1, 64)  # Input size of 1 (for t)
+        self.shared_fc2 = torch.nn.Linear(64, 32)
 
         # Output head for Process 1
-        self.process1_fc = torch.nn.Linear(100, 1)
+        self.process1_fc = torch.nn.Linear(32, 1)
 
         # Output head for Process 2
-        self.process2_fc = torch.nn.Linear(100, 1)
+        self.process2_fc = torch.nn.Linear(32, 1)
 
     def forward(self, t):
         # Shared layers forward pass
         x = torch.tanh(self.shared_fc1(t))
         x = torch.tanh(self.shared_fc2(x))
-        x = torch.tanh(self.shared_fc3(x))
-        x = torch.tanh(self.shared_fc4(x))
 
         # Process 1 output head
         process1_out = self.process1_fc(x)
@@ -82,13 +78,19 @@ def Lotka_experiment(grid_res, CACHE):
     exp_dict_list = []
     solver_device('gpu')
 
-    #net = torch.nn.Sequential(
-    #    torch.nn.Linear(1, 32),
-    #    torch.nn.Tanh(),
-    #    torch.nn.Linear(32, 32),
-    #    torch.nn.Tanh(),
-    #    torch.nn.Linear(32, 2)
-    #)
+    FFL = Fourier_embedding(L=[5], M=[2])
+
+    out = FFL.out_features
+
+
+    net = torch.nn.Sequential(
+        FFL,
+        torch.nn.Linear(out, 32),
+        torch.nn.Tanh(),
+        torch.nn.Linear(32, 32),
+        torch.nn.Tanh(),
+        torch.nn.Linear(32, 2)
+    )
 
     net=MultiOutputModel()
 
@@ -180,7 +182,7 @@
 
     optimizer = Optimizer('Adam', {'lr': 1e-4})
 
-    model.train(optimizer, 2e5, save_model=True, callbacks=[cb_es, cb_plots])
+    model.train(optimizer, 1e4, save_model=True, callbacks=[cb_es, cb_plots])
 
     end = time.time()
 
diff --git a/examples/example_Lotka_Volterra_paper_NGD.py b/examples/example_Lotka_Volterra_paper_NGD.py
new file mode 100644
index 00000000..39b928ba
--- /dev/null
+++ b/examples/example_Lotka_Volterra_paper_NGD.py
@@ -0,0 +1,330 @@
+# The Lotka-Volterra equations, also known as the predator-prey equations, describe the variation
+# in the populations of two species that interact via predation,
+# for example wolves (predators) and deer (prey). This is a classical model of the dynamics of two populations.
+
+# Let alpha > 0, beta > 0, delta > 0 and gamma > 0. The system is given by
+
+# dx/dt = x*(alpha - beta*y)
+# dy/dt = y*(-delta + gamma*x)
+
+# where 'x' is the prey population and 'y' the predator population.
+# It is a system of first-order ordinary differential equations.
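+
+# A quick sanity check for the plots below: the nontrivial equilibrium of the
+# system is (x*, y*) = (delta/gamma, alpha/beta), so with alpha = beta = delta = gamma = 20
+# both populations oscillate around the point (1, 1).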
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy import integrate
+import time
+import os
+import sys
+
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')))
+
+from tedeous.data import Domain, Conditions, Equation
+from tedeous.model import Model
+from tedeous.callbacks import cache, early_stopping, plot, adaptive_lambda
+from tedeous.optimizers.optimizer import Optimizer
+from tedeous.device import solver_device, check_device, device_type
+from tedeous.models import mat_model, Fourier_embedding, FourierNN
+
+alpha = 20.
+beta = 20.
+delta = 20.
+gamma = 20.
+x0 = 4.
+y0 = 2.
+t0 = 0.
+tmax = 1.
+
+
+from copy import deepcopy
+
+
+def train_net(net, grid, exact):
+    # Optional helper: fit the network directly to the odeint reference solution
+    # (plus a penalty on the initial condition) before PINN training.
+
+    exact = torch.Tensor(exact).float()
+
+    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
+
+    t0 = torch.Tensor([0.])
+    x0 = 4.
+    y0 = 2.
+
+    loss = torch.mean(torch.square(net(grid) - exact)) + 100*torch.mean(torch.square(net(t0) - torch.Tensor([x0, y0])))
+
+    def closure():
+        optimizer.zero_grad()
+        loss = torch.mean(torch.square(net(grid) - exact)) + 100*torch.mean(torch.square(net(t0) - torch.Tensor([x0, y0])))
+        loss.backward()
+        return loss
+
+    t = 0
+    while loss > 1e-5 and t < 1e5:
+        optimizer.step(closure)
+        loss = torch.mean(torch.square(net(grid) - exact)) + 100*torch.mean(torch.square(net(t0) - torch.Tensor([x0, y0])))
+        t += 1
+        if t % 1000 == 0:
+            print('Interpolate from exact t={}, loss={}'.format(t, loss))
+
+    return net
+
+
+# Define the model
+class MultiOutputModel(torch.nn.Module):
+    def __init__(self):
+        super(MultiOutputModel, self).__init__()
+
+        self.width_out = [2]
+
+        # Shared layers (base network)
+        self.shared_fc1 = torch.nn.Linear(1, 64)  # Input size of 1 (for t)
+        self.shared_fc2 = torch.nn.Linear(64, 32)
+
+        # Output head for Process 1
+        self.process1_fc = torch.nn.Linear(32, 1)
+
+        # Output head for Process 2
+        self.process2_fc = torch.nn.Linear(32, 1)
+
+    def forward(self, t):
+        # Shared layers forward pass
+        x = torch.tanh(self.shared_fc1(t))
+        x = torch.tanh(self.shared_fc2(x))
+
+        # Process 1 output head
+        process1_out = self.process1_fc(x)
+
+        # Process 2 output head
+        process2_out = self.process2_fc(x)
+
+        out = torch.cat((process1_out, process2_out), dim=1)
+
+        return out
+
+# Initialize the model
+#model =
+
+
+def Lotka_experiment(grid_res, CACHE):
+
+    exp_dict_list = []
+    solver_device('gpu')
+
+    FFL = Fourier_embedding(L=[1/4], M=[4])
+
+    out = FFL.out_features
+
+
+    net = torch.nn.Sequential(
+        FFL,
+        torch.nn.Linear(out, 32),
+        torch.nn.Tanh(),
+        torch.nn.Linear(32, 32),
+        torch.nn.Tanh(),
+        torch.nn.Linear(32, 2)
+    )
+
+    #net = MultiOutputModel()
+
+    #net = FourierNN([512, 512, 512, 512, 2], [15], [7])
+
+    #net = torch.nn.Sequential(
+    #    torch.nn.Linear(1, 32),
+    #    torch.nn.Tanh(),
+    #    torch.nn.Linear(32, 32),
+    #    torch.nn.Tanh(),
+    #    torch.nn.Linear(32, 2)
+    #)
+
+    #def weights_init(m):
+    #    if isinstance(m, torch.nn.Linear):
+    #        torch.nn.init.xavier_normal_(m.weight, gain=1.0)
+    #        #torch.nn.init.zero_(m.bias)
+
+    #net.apply(weights_init)
+
+
+    def exact():
+        # scipy.integrate solution of the Lotka-Volterra equations, used for comparison with the NN results
+
+        def deriv(X, t, alpha, beta, delta, gamma):
+            x, y = X
+            dotx = x * (alpha - beta * y)
+            doty = y * (-delta + gamma * x)
+            return np.array([dotx, doty])
+
+        t = np.linspace(0, tmax, grid_res+1)
+
+        X0 = [x0, y0]
+        res = integrate.odeint(deriv, X0, t, args=(alpha, beta, delta, gamma))
+        x, y = res.T
+        return np.hstack((x.reshape(-1, 1), y.reshape(-1, 1)))
+
+    u_exact = exact()
+
+    #net = train_net(net, torch.from_numpy(np.linspace(0, 1, grid_res+1)).reshape(-1, 1).float(), u_exact)
+
+
+    domain = Domain()
+    domain.variable('t', [0, tmax], grid_res)
+
+    boundaries = Conditions()
+    # initial conditions
+    boundaries.dirichlet({'t': 0}, value=x0, var=0)
+    boundaries.dirichlet({'t': 0}, value=y0, var=1)
+
+    # equation system
+    # eq1: dx/dt = x*(alpha - beta*y)
+    # eq2: dy/dt = y*(-delta + gamma*x)
+
+    # x var: 0
+    # y var: 1
+
+    equation = Equation()
+
+    eq1 = {
+        'dx/dt':{
+            'coeff': 1,
+            'term': [0],
+            'pow': 1,
+            'var': [0]
+        },
+        '-x*alpha':{
+            'coeff': -alpha,
+            'term': [None],
+            'pow': 1,
+            'var': [0]
+        },
+        '+beta*x*y':{
+            'coeff': beta,
+            'term': [[None], [None]],
+            'pow': [1, 1],
+            'var': [0, 1]
+        }
+    }
+
+    eq2 = {
+        'dy/dt':{
+            'coeff': 1,
+            'term': [0],
+            'pow': 1,
+            'var': [1]
+        },
+        '+y*delta':{
+            'coeff': delta,
+            'term': [None],
+            'pow': 1,
+            'var': [1]
+        },
+        '-gamma*x*y':{
+            'coeff': -gamma,
+            'term': [[None], [None]],
+            'pow': [1, 1],
+            'var': [0, 1]
+        }
+    }
+
+    equation.add(eq1)
+    equation.add(eq2)
+
+
+
+    model = Model(net, domain, equation, boundaries)
+
+    model.compile("autograd", lambda_operator=1, lambda_bound=100)
+
+    img_dir = os.path.join(os.path.dirname( __file__ ), 'img_Lotka_Volterra_paper_NGD')
+
+    start = time.time()
+
+    cb_es = early_stopping.EarlyStopping(eps=5e-6,
+                                         loss_window=1000,
+                                         no_improvement_patience=1000,
+                                         patience=5,
+                                         info_string_every=100,
+                                         randomize_parameter=1e-5,
+                                         save_best=True)
+
+    cb_plots = plot.Plots(save_every=1000, print_every=None, img_dir=img_dir)
+
+    cb_lambda = adaptive_lambda.AdaptiveLambda()
+
+    #cb_cache = cache.Cache(cache_verbose=True, model_randomize_parameter=1e-5)
+
+    optimizer = Optimizer('Adam', {'lr': 1e-4})
+
+    model.train(optimizer, 5e5, callbacks=[cb_es, cb_plots])
+
+    #model = Model(net, domain, equation, boundaries)
+
+    model.compile("autograd", lambda_operator=1, lambda_bound=100)
+
+
+    optimizer = Optimizer('NGD', {'grid_steps_number': 20})
+
+    cb_es = early_stopping.EarlyStopping(eps=5e-6,
+                                         loss_window=1000,
+                                         no_improvement_patience=1000,
+                                         patience=5,
+                                         info_string_every=100,
+                                         randomize_parameter=1e-5,
+                                         save_best=True)
+
+    model.train(optimizer, 2e3, callbacks=[cb_es, cb_plots])
+
+    end = time.time()
+
+    grid = domain.build('NN')
+
+
+    u_exact = torch.from_numpy(u_exact)
+
+    prediction = net(grid)
+
+    prediction_np = prediction.cpu().detach().numpy()
+    u_exact_np = u_exact.cpu().detach().numpy()
+
+
+    error_rmse = np.sqrt(np.mean((prediction_np - u_exact_np)**2))
+
+    exp_dict_list.append({'grid_res': grid_res, 'time': end - start, 'RMSE': error_rmse, 'type': 'Lotka_eqn', 'cache': CACHE})
+
+    print('Time taken {}= {}'.format(grid_res, end - start))
+    print('RMSE {}= {}'.format(grid_res, error_rmse))
+
+    #t = domain.variable_dict['t']
+    grid = domain.build('NN')
+
+    t = np.linspace(0, 1, grid_res+1)
+
+    plt.figure()
+    plt.grid()
+    plt.title("Comparison of the odeint and NN solutions")
+    plt.plot(t, u_exact[:, 0].detach().numpy().reshape(-1), '+', label='preys_odeint')
+    plt.plot(t, u_exact[:, 1].detach().numpy().reshape(-1), '*', label="predators_odeint")
+    plt.plot(grid.cpu(), net(check_device(grid))[:, 0].cpu().detach().numpy().reshape(-1), label='preys_NN')
+    plt.plot(grid.cpu(), net(check_device(grid))[:, 1].cpu().detach().numpy().reshape(-1), label='predators_NN')
+    plt.xlabel('Time t, [days]')
+    plt.ylabel('Population')
+    plt.legend(loc='upper right')
+    plt.savefig(os.path.join(img_dir, 'compare_{}.png'.format(grid_res)))
+
+
+    return exp_dict_list
+
+nruns = 1
+
+exp_dict_list = []
+
+CACHE = False
+
+for grid_res in range(500, 1001, 100):
+    for _ in range(nruns):
+        exp_dict_list.append(Lotka_experiment(grid_res, CACHE))
+
+
+
+#import pandas as pd
+
+#exp_dict_list_flatten = [item for sublist in exp_dict_list for item in sublist]
+#df=pd.DataFrame(exp_dict_list_flatten)
+#df.boxplot(by='grid_res',column='time',fontsize=42,figsize=(20,10))
+#df.boxplot(by='grid_res',column='RMSE',fontsize=42,figsize=(20,10),showfliers=False)
+#df.to_csv('benchmarking_data/Lotka_experiment_50_90_cache={}.csv'.format(str(CACHE)))
\ No newline at end of file
diff --git a/tedeous/callbacks/early_stopping.py b/tedeous/callbacks/early_stopping.py
index 19fd8ed0..9402b12c 100644
--- a/tedeous/callbacks/early_stopping.py
+++ b/tedeous/callbacks/early_stopping.py
@@ -18,7 +18,8 @@ def __init__(self,
                  normalized_loss: bool = False,
                  randomize_parameter: float = 1e-5,
                  info_string_every: Union[int, None] = None,
-                 verbose: bool = True
+                 verbose: bool = True,
+                 save_best: bool = False
                  ):
         """_summary_
 
@@ -36,6 +37,7 @@ def __init__(self,
             info_string_every (Union[int, None], optional): prints the loss state after every *int* step. Defaults to None.
             verbose (bool, optional): print or not info about loss and current state of stopping criteria. Defaults to True.
+            save_best (bool, optional): if True, the model with the lowest loss seen during training is kept and returned as the result. Defaults to False.
         """
         super().__init__()
         self.eps = eps
@@ -49,6 +51,10 @@ def __init__(self,
         self._r = create_random_fn(randomize_parameter)
         self.info_string_every = info_string_every if info_string_every is not None else np.inf
         self.verbose = verbose
+        self.save_best = save_best
+        self.best_model = None
+
+
     def _line_create(self):
         """ Approximating last_loss list (len(last_loss)=loss_oscillation_window) by the line.
@@ -78,6 +84,8 @@ def _patience_check(self):
             self._stop_dings += 1
             self._t_imp_start = self.t
         if self.mode in ('NN', 'autograd'):
+            if self.save_best:
+                self.model.net = self.best_model
             self.model.net.apply(self._r)
         self._check = 'patience_check'
 
@@ -120,12 +128,16 @@ def on_epoch_end(self, logs=None):
         if self.model.cur_loss < self.model.min_loss:
             self.model.min_loss = self.model.cur_loss
+            if self.save_best:
+                self.best_model = self.model.net
             self._t_imp_start = self.t
         if self.verbose:
             self.verbose_print()
         if self._stop_dings >= self.patience:
             self.model.stop_training = True
+            if self.save_best:
+                self.model.net = self.best_model
         self._check = None
 
     def on_epoch_begin(self, logs=None):
diff --git a/tedeous/model.py b/tedeous/model.py
index 53896d9c..d43c5d28 100644
--- a/tedeous/model.py
+++ b/tedeous/model.py
@@ -92,9 +92,9 @@ def compile(
         self.equation_cls = Operator_bcond_preproc(grid, operator, bconds, h=h,
                                                    inner_order=inner_order,
                                                    boundary_order=boundary_order).set_strategy(mode)
-
-        if len(grid)
diff --git a/tedeous/optimizers/ngd.py b/tedeous/optimizers/ngd.py
--- a/tedeous/optimizers/ngd.py
+++ b/tedeous/optimizers/ngd.py
 ) -> None:
         """ Update model parameters by natural gradient.
 
@@ -73,6 +75,31 @@ def jacobian() -> torch.Tensor:
         J = jacobian()
         return 1.0 / len(residuals) * J.T @ J
+
+
+    def gram_factory_cpu(self, residuals: torch.Tensor) -> torch.Tensor:
+        """ Assemble the Gram matrix on the CPU.
+
+        Args:
+            residuals (callable): PDE residual.
+
+        Returns:
+            torch.Tensor: Gram matrix.
+        """
+        # Assemble the Gram matrix.
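+        # One row of J per residual entry, J[i, :] = d r_i / d theta, so that
+        # G = (1/N) * J^T @ J is the Gauss-Newton approximation used as the
+        # natural-gradient metric. This mirrors gram_factory exactly, except
+        # that J is moved to the CPU before the O(P^2) product, avoiding the
+        # CUDA out-of-memory failures the GPU path can hit.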
+        def jacobian() -> torch.Tensor:
+            jac = []
+            for l in residuals:
+                j = torch.autograd.grad(l, self.params, retain_graph=True, allow_unused=True)
+                j = replace_none_by_zero(j)
+                j = parameters_to_vector(j).reshape(1, -1)
+                jac.append(j)
+            return torch.cat(jac)
+
+        J = jacobian().cpu()
+        return 1.0 / len(residuals) * J.T @ J
+
+
     def torch_cuda_lstsq(self, A: torch.Tensor, B: torch.Tensor, tol: float = None) -> torch.Tensor:
         """ Find lstsq (least-squares solution) for torch.tensor cuda.
@@ -95,6 +122,22 @@ def torch_cuda_lstsq(self, A: torch.Tensor, B: torch.Tensor, tol: float = None)
         SpinvUhB = Spinv * UhB
         return Vh.adjoint() @ SpinvUhB
 
+
+    def numpy_lstsq(self, A: torch.Tensor, B: torch.Tensor, rcond: float = None) -> torch.Tensor:
+        """ CPU fallback: least-squares solution of A @ x = B via numpy.linalg.lstsq. """
+
+        A = A.detach().cpu().numpy()
+        B = B.detach().cpu().numpy()
+
+        f_nat_grad = np.linalg.lstsq(A, B, rcond=rcond)[0]
+
+        f_nat_grad = torch.from_numpy(f_nat_grad)
+
+        f_nat_grad = check_device(f_nat_grad)
+
+        return f_nat_grad
+
+
     def step(self, closure=None) -> torch.Tensor:
         """ It runs ONE step of the natural gradient descent.
@@ -109,17 +152,58 @@ def step(self, closure=None) -> torch.Tensor:
 
         bound_res = bval - true_bval
 
-        # assemble gramian
-        G_int = self.gram_factory(int_res)
-        G_bdry = self.gram_factory(bound_res)
-        G = G_int + G_bdry
+        ## assemble gramian
+        #G_int = self.gram_factory(int_res.reshape(-1))
+        #G_bdry = self.gram_factory(bound_res.reshape(-1))
+        #G = G_int + G_bdry
+
+        ## Marquardt-Levenberg
+        #Id = torch.eye(len(G))
+        #G = torch.min(torch.tensor([loss, 0.0])) * Id + G
 
-        # Marquardt-Levenberg
-        Id = torch.eye(len(G))
-        G = torch.min(torch.tensor([loss, 0.0])) * Id + G
+        # compute natural gradient
 
-        f_nat_grad = self.torch_cuda_lstsq(G, f_grads)
+        if not self.cuda_out_of_memory_flag:
+            try:
+                if self.cuda_empty_once_for_test:
+                    #print('Initial GPU check')
+                    torch.cuda.empty_cache()
+                    self.cuda_empty_once_for_test = False
+
+                # assemble gramian
+
+                #print('NGD GPU step')
+
+                G_int = self.gram_factory(int_res.reshape(-1))
+                G_bdry = self.gram_factory(bound_res.reshape(-1))
+                G = G_int + G_bdry
+
+                # Marquardt-Levenberg
+                Id = torch.eye(len(G))
+                G = torch.min(torch.tensor([loss, 0.0])) * Id + G
+
+                f_nat_grad = self.torch_cuda_lstsq(G, f_grads)
+            except torch.OutOfMemoryError:
+                print('[Warning] The least-squares solve ran out of CUDA memory; falling back to CPU and RAM, which is significantly slower')
+                self.cuda_out_of_memory_flag = True
+
+                G_int = self.gram_factory_cpu(int_res.reshape(-1).cpu())
+                G_bdry = self.gram_factory_cpu(bound_res.reshape(-1).cpu())
+                G = G_int + G_bdry
+
+
+                f_nat_grad = self.numpy_lstsq(G, f_grads)
+        else:
+
+
+            #print('NGD CPU step')
+
+            G_int = self.gram_factory_cpu(int_res.reshape(-1).cpu())
+            G_bdry = self.gram_factory_cpu(bound_res.reshape(-1).cpu())
+            G = G_int + G_bdry
+
+            f_nat_grad = self.numpy_lstsq(G, f_grads)
 
         # one step of NGD
         self.grid_line_search_update(loss_function, f_nat_grad)
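
For reviewers who want to poke at the fallback logic outside of TEDEouS, the following self-contained sketch reproduces the essence of the new NGD step on a toy least-squares problem. It is an illustration only: toy_net, the t^2 target, the loss-proportional damping, and the fixed 0.1 step size are stand-ins (the patched optimizer damps with min(loss, 0) and chooses the step length by a grid line search), not part of the TEDEouS API.

    import numpy as np
    import torch
    from torch.nn.utils import parameters_to_vector, vector_to_parameters

    torch.manual_seed(0)

    # Toy problem: fit t^2 on [0, 1]; the residual vector stands in for the PDE residual.
    toy_net = torch.nn.Sequential(torch.nn.Linear(1, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))
    params = list(toy_net.parameters())

    t = torch.linspace(0, 1, 32).reshape(-1, 1)
    target = t ** 2

    res = (toy_net(t) - target).reshape(-1)
    loss = torch.mean(res ** 2)

    # Jacobian of the residuals w.r.t. the parameters, one row per residual,
    # mirroring gram_factory / gram_factory_cpu above.
    rows = []
    for r in res:
        grads = torch.autograd.grad(r, params, retain_graph=True, allow_unused=True)
        grads = [torch.zeros_like(p) if g is None else g for g, p in zip(grads, params)]
        rows.append(parameters_to_vector(grads).reshape(1, -1))
    J = torch.cat(rows)

    G = J.T @ J / len(res)                       # Gram (Gauss-Newton) matrix
    G = G + loss.item() * torch.eye(G.shape[0])  # damping; the patch uses min(loss, 0) here

    f_grads = parameters_to_vector(torch.autograd.grad(loss, params))

    # CPU least-squares solve, as in numpy_lstsq.
    f_nat_grad = np.linalg.lstsq(G.numpy(), f_grads.numpy(), rcond=None)[0]

    # Fixed step in place of the grid line search.
    new_flat = parameters_to_vector(params) - 0.1 * torch.from_numpy(f_nat_grad)
    vector_to_parameters(new_flat, params)

    new_loss = torch.mean((toy_net(t) - target).reshape(-1) ** 2)
    print('loss before: {:.5f}, after: {:.5f}'.format(loss.item(), new_loss.item()))

Using a least-squares solve rather than a direct inverse is deliberate: when the Gram matrix is rank-deficient (common when the network is overparameterized relative to the number of residuals), lstsq still returns the minimum-norm natural-gradient direction instead of failing.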