Skip to content

Commit

Permalink
More stable NGD
Browse files Browse the repository at this point in the history
  • Loading branch information
SuperSashka committed Sep 4, 2024
1 parent 3701592 commit 42bc624
Show file tree
Hide file tree
Showing 5 changed files with 458 additions and 30 deletions.
36 changes: 19 additions & 17 deletions examples/example_Lotka_Volterra_paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tedeous.callbacks import cache, early_stopping, plot
from tedeous.optimizers.optimizer import Optimizer
from tedeous.device import solver_device, check_device, device_type

from tedeous.models import Fourier_embedding

alpha = 20.
beta = 20.
Expand All @@ -47,22 +47,18 @@ def __init__(self):
self.width_out=[2]

# Shared layers (base network)
self.shared_fc1 = torch.nn.Linear(1, 100) # Input size of 1 (for t)
self.shared_fc2 = torch.nn.Linear(100, 100)
self.shared_fc3 = torch.nn.Linear(100, 100)
self.shared_fc4 = torch.nn.Linear(100, 100)
self.shared_fc1 = torch.nn.Linear(1, 64) # Input size of 1 (for t)
self.shared_fc2 = torch.nn.Linear(64, 32)
# Output head for Process 1
self.process1_fc = torch.nn.Linear(100, 1)
self.process1_fc = torch.nn.Linear(32, 1)

# Output head for Process 2
self.process2_fc = torch.nn.Linear(100, 1)
self.process2_fc = torch.nn.Linear(32, 1)

def forward(self, t):
# Shared layers forward pass
x = torch.tanh(self.shared_fc1(t))
x = torch.tanh(self.shared_fc2(x))
x = torch.tanh(self.shared_fc3(x))
x = torch.tanh(self.shared_fc4(x))
# Process 1 output head
process1_out = self.process1_fc(x)

Expand All @@ -82,13 +78,19 @@ def Lotka_experiment(grid_res, CACHE):
exp_dict_list = []
solver_device('gpu')

#net = torch.nn.Sequential(
# torch.nn.Linear(1, 32),
# torch.nn.Tanh(),
# torch.nn.Linear(32, 32),
# torch.nn.Tanh(),
# torch.nn.Linear(32, 2)
#)
FFL = Fourier_embedding(L=[5], M=[2])

out = FFL.out_features


net = torch.nn.Sequential(
FFL,
torch.nn.Linear(out, 32),
torch.nn.Tanh(),
torch.nn.Linear(32, 32),
torch.nn.Tanh(),
torch.nn.Linear(32, 2)
)

net=MultiOutputModel()

Expand Down Expand Up @@ -180,7 +182,7 @@ def Lotka_experiment(grid_res, CACHE):

optimizer = Optimizer('Adam', {'lr': 1e-4})

model.train(optimizer, 2e5, save_model=True, callbacks=[cb_es, cb_plots])
model.train(optimizer, 1e4, save_model=True, callbacks=[cb_es, cb_plots])

end = time.time()

Expand Down
Loading

0 comments on commit 42bc624

Please sign in to comment.