-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
6 changed files
with
311 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import torch | ||
import os | ||
import sys | ||
import scipy | ||
|
||
|
||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' | ||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..'))) | ||
|
||
from tedeous.data import Domain, Conditions, Equation | ||
from tedeous.model import Model | ||
from tedeous.callbacks import early_stopping, plot | ||
from tedeous.optimizers.optimizer import Optimizer | ||
from tedeous.device import solver_device, check_device | ||
|
||
|
||
def exact_solution(grid): | ||
grid = grid.to('cpu').detach() | ||
test_data = scipy.io.loadmat(os.path.abspath( | ||
os.path.join(os.path.dirname( __file__ ), 'wolfram_sln/buckley_exact.mat'))) | ||
u = torch.from_numpy(test_data['u']).reshape(-1, 1) | ||
|
||
# grid_test | ||
x = torch.from_numpy(test_data['x']).reshape(-1, 1) | ||
t = torch.from_numpy(test_data['t']).reshape(-1, 1) | ||
|
||
grid_data = torch.cat((x, t), dim=1) | ||
|
||
exact = scipy.interpolate.griddata(grid_data, u, grid, method='nearest').reshape(-1) | ||
|
||
return torch.from_numpy(exact) | ||
|
||
|
||
solver_device('cuda') | ||
|
||
m = 0.2 | ||
L = 1 | ||
Q = -0.1 | ||
Sq = 1 | ||
mu_w = 0.89e-3 | ||
mu_o = 4.62e-3 | ||
Swi0 = 0. | ||
Sk = 1. | ||
t_end = 1. | ||
|
||
|
||
def experiment(grid_res, mode): | ||
|
||
domain = Domain() | ||
|
||
domain.variable('x', [0, 1], grid_res, dtype='float32') | ||
domain.variable('t', [0, 1], grid_res, dtype='float32') | ||
|
||
boundaries = Conditions() | ||
|
||
##initial cond | ||
boundaries.dirichlet({'x': [0, 1], 't': 0}, value=Swi0) | ||
|
||
##boundary cond | ||
boundaries.dirichlet({'x': 0, 't': [0, 1]}, value=Sk) | ||
|
||
net = torch.nn.Sequential( | ||
torch.nn.Linear(2, 20), | ||
torch.nn.Tanh(), | ||
torch.nn.Linear(20, 20), | ||
torch.nn.Tanh(), | ||
torch.nn.Linear(20, 20), | ||
torch.nn.Tanh(), | ||
torch.nn.Linear(20, 1) | ||
) | ||
|
||
def k_oil(x): | ||
return (1-net(x))**2 | ||
|
||
def k_water(x): | ||
return (net(x))**2 | ||
|
||
def dk_water(x): | ||
return 2*net(x) | ||
|
||
def dk_oil(x): | ||
return -2*(1-net(x)) | ||
|
||
def df(x): | ||
return (dk_water(x)*(k_water(x)+mu_w/mu_o*k_oil(x))- | ||
k_water(x)*(dk_water(x)+mu_w/mu_o*dk_oil(x)))/(k_water(x)+mu_w/mu_o*k_oil(x))**2 | ||
|
||
def coef_model(x): | ||
return -Q/Sq*df(x) | ||
|
||
equation = Equation() | ||
|
||
buckley_eq = { | ||
'm*ds/dt**1': | ||
{ | ||
'coeff': m, | ||
'ds/dt': [1], | ||
'pow': 1 | ||
}, | ||
'-Q/Sq*df*ds/dx**1': | ||
{ | ||
'coeff': coef_model, | ||
'ds/dx': [0], | ||
'pow':1 | ||
} | ||
} | ||
|
||
equation.add(buckley_eq) | ||
|
||
model = Model(net, domain, equation, boundaries) | ||
|
||
model.compile(mode, lambda_operator=1, lambda_bound=10) | ||
|
||
img_dir=os.path.join(os.path.dirname( __file__ ), 'Buckley_img') | ||
|
||
|
||
cb_es = early_stopping.EarlyStopping(eps=1e-6, | ||
loss_window=100, | ||
no_improvement_patience=500, | ||
patience=5, | ||
abs_loss=1e-5, | ||
randomize_parameter=1e-5, | ||
info_string_every=1000) | ||
|
||
cb_plots = plot.Plots(save_every=1000, print_every=None, img_dir=img_dir) | ||
|
||
# model.train(optimizer, 10000, save_model=False, callbacks=[cb_es, cb_plots]) | ||
|
||
grid = domain.build(mode) | ||
|
||
u_exact = exact_solution(grid).to('cuda') | ||
|
||
u_exact = check_device(u_exact).reshape(-1) | ||
|
||
u_pred = check_device(net(grid)).reshape(-1) | ||
|
||
error_rmse = torch.sqrt(torch.sum((u_exact - u_pred)**2)) / torch.sqrt(torch.sum(u_exact**2)) | ||
|
||
print('RMSE_adam= ', error_rmse.item()) | ||
|
||
################# | ||
|
||
optimizer = Optimizer('NGD', {'grid_steps_number': 20}) | ||
|
||
cb_plots = plot.Plots(save_every=100, print_every=None, img_dir=img_dir) | ||
|
||
model.train(optimizer, 3000, info_string_every=100, save_model=False, callbacks=[cb_plots]) | ||
|
||
u_pred = check_device(net(grid)).reshape(-1) | ||
|
||
error_rmse = torch.sqrt(torch.sum((u_exact - u_pred)**2)) / torch.sqrt(torch.sum(u_exact**2)) | ||
|
||
print('RMSE_pso= ', error_rmse.item()) | ||
|
||
return net | ||
|
||
for i in range(1): | ||
model = experiment(20, 'autograd') | ||
|
||
## After experiment, RMSE_adam ~ 0.23, RMSE_pso ~ 0.19 or less. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import torch | ||
from numpy.linalg import lstsq | ||
import numpy as np | ||
from torch.nn.utils import parameters_to_vector, vector_to_parameters | ||
from tedeous.utils import replace_none_by_zero | ||
|
||
|
||
class NGD(torch.optim.Optimizer): | ||
|
||
"""NGD implementation (https://arxiv.org/abs/2302.13163). | ||
""" | ||
|
||
def __init__(self, params, | ||
grid_steps_number: int = 30): | ||
"""The Natural Gradient Descent class. | ||
Args: | ||
grid_steps_number (int, optional): Grid steps number. Defaults to 30. | ||
""" | ||
defaults = {'grid_steps_number': grid_steps_number} | ||
super(NGD, self).__init__(params, defaults) | ||
self.params = self.param_groups[0]['params'] | ||
self.grid_steps_number = grid_steps_number | ||
self.grid_steps = torch.linspace(0, self.grid_steps_number, self.grid_steps_number + 1) | ||
self.steps = 0.5**self.grid_steps | ||
|
||
def grid_line_search_update(self, loss_function: callable, f_nat_grad: torch.Tensor) -> None: | ||
""" Update models paramters by natural gradient. | ||
Args: | ||
loss (callable): function to calculate loss. | ||
Returns: | ||
None. | ||
""" | ||
# function to update models paramters at each step | ||
def loss_at_step(step, loss_function: callable, f_nat_grad: torch.Tensor) -> torch.Tensor: | ||
params = parameters_to_vector(self.params) | ||
new_params = params - step * f_nat_grad | ||
vector_to_parameters(new_params, self.params) | ||
loss_val, _ = loss_function() | ||
vector_to_parameters(params, self.params) | ||
return loss_val | ||
|
||
losses = [] | ||
for step in self.steps: | ||
losses.append(loss_at_step(step, loss_function, f_nat_grad).reshape(1)) | ||
losses = torch.cat(losses) | ||
step_size = self.steps[torch.argmin(losses)] | ||
|
||
params = parameters_to_vector(self.params) | ||
new_params = params - step_size * f_nat_grad | ||
vector_to_parameters(new_params, self.params) | ||
|
||
def gram_factory(self, residuals: torch.Tensor) -> torch.Tensor: | ||
""" Make Gram matrice. | ||
Args: | ||
residuals (callable): PDE residual. | ||
Returns: | ||
torch.Tensor: Gram matrice. | ||
""" | ||
# Make Gram matrice. | ||
def jacobian() -> torch.Tensor: | ||
jac = [] | ||
for l in residuals: | ||
j = torch.autograd.grad(l, self.params, retain_graph=True, allow_unused=True) | ||
j = replace_none_by_zero(j) | ||
j = parameters_to_vector(j).reshape(1, -1) | ||
jac.append(j) | ||
return torch.cat(jac) | ||
|
||
J = jacobian() | ||
return 1.0 / len(residuals) * J.T @ J | ||
|
||
def step(self, closure=None) -> torch.Tensor: | ||
""" It runs ONE step on the natural gradient descent. | ||
Returns: | ||
torch.Tensor: loss value for NGD step. | ||
""" | ||
|
||
int_res, bval, true_bval, loss, loss_function = closure() | ||
grads = torch.autograd.grad(loss, self.params, retain_graph=True, allow_unused=True) | ||
grads = replace_none_by_zero(grads) | ||
f_grads = parameters_to_vector(grads) | ||
|
||
bound_res = bval-true_bval | ||
|
||
# assemble gramian | ||
G_int = self.gram_factory(int_res) | ||
G_bdry = self.gram_factory(bound_res) | ||
G = G_int + G_bdry | ||
|
||
# Marquardt-Levenberg | ||
Id = torch.eye(len(G)) | ||
G = torch.min(torch.tensor([loss, 0.0])) * Id + G | ||
|
||
# compute natural gradient | ||
G = np.array(G.detach().cpu().numpy(), dtype=np.float32) | ||
f_grads = np.array(f_grads.detach().cpu().numpy(), dtype=np.float32) | ||
f_nat_grad = lstsq(G, f_grads)[0] | ||
f_nat_grad = torch.from_numpy(np.array(f_nat_grad)).to(torch.float32).to('cuda') | ||
|
||
# one step of NGD | ||
self.grid_line_search_update(loss_function, f_nat_grad) | ||
self.param_groups[0]['params'] = self.params | ||
|
||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
__version__ = '0.4.3' | ||
__version__ = '0.4.4' | ||
|