diff --git a/examples/example_Lotka_Volterra_paper.py b/examples/example_Lotka_Volterra_paper.py index d87d51a1..c782f7ee 100644 --- a/examples/example_Lotka_Volterra_paper.py +++ b/examples/example_Lotka_Volterra_paper.py @@ -47,20 +47,22 @@ def __init__(self): self.width_out=[2] # Shared layers (base network) - self.shared_fc1 = torch.nn.Linear(1, 64) # Input size of 1 (for t) - self.shared_fc2 = torch.nn.Linear(64, 32) - + self.shared_fc1 = torch.nn.Linear(1, 100) # Input size of 1 (for t) + self.shared_fc2 = torch.nn.Linear(100, 100) + self.shared_fc3 = torch.nn.Linear(100, 100) + self.shared_fc4 = torch.nn.Linear(100, 100) # Output head for Process 1 - self.process1_fc = torch.nn.Linear(32, 1) + self.process1_fc = torch.nn.Linear(100, 1) # Output head for Process 2 - self.process2_fc = torch.nn.Linear(32, 1) + self.process2_fc = torch.nn.Linear(100, 1) def forward(self, t): # Shared layers forward pass x = torch.tanh(self.shared_fc1(t)) x = torch.tanh(self.shared_fc2(x)) - + x = torch.tanh(self.shared_fc3(x)) + x = torch.tanh(self.shared_fc4(x)) # Process 1 output head process1_out = self.process1_fc(x) @@ -157,7 +159,7 @@ def Lotka_experiment(grid_res, CACHE): - model = Model(net, domain, equation, boundaries) + model = Model(net, domain, equation, boundaries, batch_size=64) model.compile("autograd", lambda_operator=1, lambda_bound=100) @@ -230,7 +232,7 @@ def deriv(X, t, alpha, beta, delta, gamma): plt.xlabel('Time t, [days]') plt.ylabel('Population') plt.legend(loc='upper right') - plt.savefig(os.path.join(img_dir,'compare_{}_{}.png'.format(grid_res,part))) + plt.savefig(os.path.join(img_dir,'compare_{}.png'.format(grid_res))) return exp_dict_list @@ -241,7 +243,7 @@ def deriv(X, t, alpha, beta, delta, gamma): CACHE=False -for grid_res in range(60,101,10): +for grid_res in range(60,1001,100): for _ in range(nruns): exp_dict_list.append(Lotka_experiment(grid_res,CACHE)) diff --git a/tedeous/model.py b/tedeous/model.py index 30345f7f..53896d9c 100644 --- a/tedeous/model.py +++ b/tedeous/model.py @@ -93,6 +93,9 @@ def compile( self.equation_cls = Operator_bcond_preproc(grid, operator, bconds, h=h, inner_order=inner_order, boundary_order=boundary_order).set_strategy(mode) + if len(grid)