Skip to content

Commit

Permalink
Lotka-Volterra example workaround
Browse files Browse the repository at this point in the history
  • Loading branch information
SuperSashka committed Aug 23, 2024
1 parent 312f8ef commit 382111e
Showing 1 changed file with 74 additions and 24 deletions.
98 changes: 74 additions & 24 deletions examples/example_Lotka_Volterra_paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,68 @@
x0 = 4.
y0 = 2.
t0 = 0.
tmax = 1.
tmax = 1


from copy import deepcopy


# Define the model
class MultiOutputModel(torch.nn.Module):
def __init__(self):
super(MultiOutputModel, self).__init__()

self.width_out=[2]

# Shared layers (base network)
self.shared_fc1 = torch.nn.Linear(1, 64) # Input size of 1 (for t)
self.shared_fc2 = torch.nn.Linear(64, 32)

# Output head for Process 1
self.process1_fc = torch.nn.Linear(32, 1)

# Output head for Process 2
self.process2_fc = torch.nn.Linear(32, 1)

def forward(self, t):
# Shared layers forward pass
x = torch.tanh(self.shared_fc1(t))
x = torch.tanh(self.shared_fc2(x))

# Process 1 output head
process1_out = self.process1_fc(x)

# Process 2 output head
process2_out = self.process2_fc(x)

out=torch.cat((process1_out, process2_out), dim=1)

return out

# Initialize the model
#model =


def Lotka_experiment(grid_res, CACHE):
exp_dict_list = []

exp_dict_list = []
solver_device('gpu')

domain = Domain()
domain.variable('t', [0, tmax], grid_res)
#net = torch.nn.Sequential(
# torch.nn.Linear(1, 32),
# torch.nn.Tanh(),
# torch.nn.Linear(32, 32),
# torch.nn.Tanh(),
# torch.nn.Linear(32, 2)
#)

h = 0.0001
net=MultiOutputModel()




domain = Domain()
domain.variable('t', [0, 1], grid_res)

boundaries = Conditions()
#initial conditions
Expand Down Expand Up @@ -104,15 +155,9 @@ def Lotka_experiment(grid_res, CACHE):
equation.add(eq1)
equation.add(eq2)

net = torch.nn.Sequential(
torch.nn.Linear(1, 32),
torch.nn.Tanh(),
torch.nn.Linear(32, 32),
torch.nn.Tanh(),
torch.nn.Linear(32, 2)
)

model = Model(net, domain, equation, boundaries,batch_size=16)

model = Model(net, domain, equation, boundaries)

model.compile("autograd", lambda_operator=1, lambda_bound=100)

Expand All @@ -121,17 +166,19 @@ def Lotka_experiment(grid_res, CACHE):
start = time.time()

cb_es = early_stopping.EarlyStopping(eps=1e-6,
loss_window=100,
no_improvement_patience=500,
patience=3,
info_string_every=100,
randomize_parameter=1e-5)
loss_window=1000,
no_improvement_patience=500,
patience=3,
info_string_every=100,
randomize_parameter=1e-5)

cb_plots = plot.Plots(save_every=1000, print_every=None, img_dir=img_dir)

#cb_cache = cache.Cache(cache_verbose=True, model_randomize_parameter=1e-5)

optimizer = Optimizer('Adam', {'lr': 1e-4})

model.train(optimizer, 5e6, save_model=True, callbacks=[cb_es, cb_plots])
model.train(optimizer, 2e5, save_model=True, callbacks=[cb_es, cb_plots])

end = time.time()

Expand All @@ -150,7 +197,7 @@ def deriv(X, t, alpha, beta, delta, gamma):
doty = y * (-delta + gamma * x)
return np.array([dotx, doty])

t = np.linspace(0., tmax, grid_res+1)
t = np.linspace(0, 1, grid_res+1)

X0 = [x0, y0]
res = integrate.odeint(deriv, X0, t, args = (alpha, beta, delta, gamma))
Expand All @@ -168,20 +215,23 @@ def deriv(X, t, alpha, beta, delta, gamma):
print('Time taken {}= {}'.format(grid_res, end - start))
print('RMSE {}= {}'.format(grid_res, error_rmse))

t = domain.variable_dict['t']
#t = domain.variable_dict['t']
grid = domain.build('NN')

t = np.linspace(0, 1, grid_res+1)

plt.figure()
plt.grid()
plt.title("odeint and NN methods comparing")
plt.plot(t.cpu(), u_exact[:,0].detach().numpy().reshape(-1), '+', label = 'preys_odeint')
plt.plot(t.cpu(), u_exact[:,1].detach().numpy().reshape(-1), '*', label = "predators_odeint")
plt.plot(t, u_exact[:,0].detach().numpy().reshape(-1), '+', label = 'preys_odeint')
plt.plot(t, u_exact[:,1].detach().numpy().reshape(-1), '*', label = "predators_odeint")
plt.plot(grid.cpu(), net(grid.cpu())[:,0].detach().numpy().reshape(-1), label='preys_NN')
plt.plot(grid.cpu(), net(grid.cpu())[:,1].detach().numpy().reshape(-1), label='predators_NN')
plt.xlabel('Time t, [days]')
plt.ylabel('Population')
plt.legend(loc='upper right')
plt.show()
plt.savefig(os.path.join(img_dir,'compare_{}_{}.png'.format(grid_res,part)))


return exp_dict_list

Expand Down

0 comments on commit 382111e

Please sign in to comment.