Skip to content

Commit

Permalink
batch size fix
Browse files Browse the repository at this point in the history
  • Loading branch information
SuperSashka committed Aug 26, 2024
1 parent 382111e commit a65fbab
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
20 changes: 11 additions & 9 deletions examples/example_Lotka_Volterra_paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,22 @@ def __init__(self):
self.width_out=[2]

# Shared layers (base network)
self.shared_fc1 = torch.nn.Linear(1, 64) # Input size of 1 (for t)
self.shared_fc2 = torch.nn.Linear(64, 32)

self.shared_fc1 = torch.nn.Linear(1, 100) # Input size of 1 (for t)
self.shared_fc2 = torch.nn.Linear(100, 100)
self.shared_fc3 = torch.nn.Linear(100, 100)
self.shared_fc4 = torch.nn.Linear(100, 100)
# Output head for Process 1
self.process1_fc = torch.nn.Linear(32, 1)
self.process1_fc = torch.nn.Linear(100, 1)

# Output head for Process 2
self.process2_fc = torch.nn.Linear(32, 1)
self.process2_fc = torch.nn.Linear(100, 1)

def forward(self, t):
# Shared layers forward pass
x = torch.tanh(self.shared_fc1(t))
x = torch.tanh(self.shared_fc2(x))

x = torch.tanh(self.shared_fc3(x))
x = torch.tanh(self.shared_fc4(x))
# Process 1 output head
process1_out = self.process1_fc(x)

Expand Down Expand Up @@ -157,7 +159,7 @@ def Lotka_experiment(grid_res, CACHE):



model = Model(net, domain, equation, boundaries)
model = Model(net, domain, equation, boundaries, batch_size=64)

model.compile("autograd", lambda_operator=1, lambda_bound=100)

Expand Down Expand Up @@ -230,7 +232,7 @@ def deriv(X, t, alpha, beta, delta, gamma):
plt.xlabel('Time t, [days]')
plt.ylabel('Population')
plt.legend(loc='upper right')
plt.savefig(os.path.join(img_dir,'compare_{}_{}.png'.format(grid_res,part)))
plt.savefig(os.path.join(img_dir,'compare_{}.png'.format(grid_res)))


return exp_dict_list
Expand All @@ -241,7 +243,7 @@ def deriv(X, t, alpha, beta, delta, gamma):

CACHE=False

for grid_res in range(60,101,10):
for grid_res in range(60,1001,100):
for _ in range(nruns):
exp_dict_list.append(Lotka_experiment(grid_res,CACHE))

Expand Down
3 changes: 3 additions & 0 deletions tedeous/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def compile(
self.equation_cls = Operator_bcond_preproc(grid, operator, bconds, h=h, inner_order=inner_order,
boundary_order=boundary_order).set_strategy(mode)

if len(grid)<self.batch_size:
self.batch_size=None

self.solution_cls = Solution(grid, self.equation_cls, self.net, mode, weak_form,
lambda_operator, lambda_bound, tol, derivative_points,
batch_size=self.batch_size)
Expand Down

0 comments on commit a65fbab

Please sign in to comment.