Commit: merging
lenasal committed Jan 9, 2025
2 parents 89af353 + 0a20382 commit 8f8b8ce
Showing 1 changed file with 88 additions and 37 deletions.

neurolib/control/optimal_control/oc_jax.py
@@ -5,11 +5,11 @@
 import copy
 from neurolib.models.jax.wc import WCModel
 from neurolib.models.jax.wc.timeIntegration import timeIntegration_args, timeIntegration_elementwise
-from neurolib.optimize.autodiff.wc_optimizer import args_names
 
 import logging
 from neurolib.control.optimal_control.oc import getdefaultweights
 
+# TODO: introduce for all models, not just WC
 wc_default_control_params = ["exc_ext", "inh_ext"]
 wc_default_target_params = ["exc", "inh"]

@@ -35,78 +35,129 @@ def hilbert_jax(signal, axis=-1):
     return analytic_signal
 
 
-class OcWc:
+class Optimize:
     def __init__(
         self,
         model,
+        loss_function,
+        param_names,
+        target_param_names,
         target=None,
-        optimizer=optax.adam(1e-3),
-        control_params=wc_default_control_params,
-        target_params=wc_default_target_params,
+        init_params=None,
+        optimizer=optax.adabelief(1e-3),
     ):
-        assert isinstance(control_params, (list, tuple)) and len(control_params) > 0
-        assert isinstance(target_params, (list, tuple)) and len(target_params) > 0
-        assert all([cp in wc_default_control_params for cp in control_params])
-        assert all([tp in wc_default_target_params for tp in target_params])
+        assert isinstance(param_names, (list, tuple)) and len(param_names) > 0
+        assert isinstance(target_param_names, (list, tuple)) and len(target_param_names) > 0
+        assert all([p in model.args_names for p in param_names])
+        assert all([tp in model.output_vars for tp in target_param_names])

         self.model = copy.deepcopy(model)
+        self.loss_function = loss_function
         self.target = target
         self.optimizer = optimizer
-        self.control_params = control_params
-        self.target_params = target_params
 
-        self.weights = getdefaultweights()
+        self.param_names = param_names
+        self.target_param_names = target_param_names
 
         args_values = timeIntegration_args(self.model.params)
-        self.args = dict(zip(args_names, args_values))
+        self.args = dict(zip(self.model.args_names, args_values))
 
-        self.loss = self.get_loss()
-        self.compute_gradient = jax.jit(jax.grad(self.loss))
         self.T = len(self.args["t"])
         self.startind = self.model.getMaxDelay()
-        self.control = jnp.zeros((len(control_params), self.model.params.N, self.T), dtype=float)
-        self.opt_state = self.optimizer.init(self.control)
+        if init_params is not None:
+            self.params = init_params
+        else:
+            self.params = dict(zip(param_names, [self.args[p] for p in param_names]))
+        self.opt_state = self.optimizer.init(self.params)
+
+        compute_loss = lambda params: self.loss_function(params, self.get_output(params))
+        self.compute_loss = jax.jit(compute_loss)
+        self.compute_gradient = jax.jit(jax.grad(self.compute_loss))
 
         self.cost_history = []

-    def simulate(self, control):
+    # TODO: allow arbitrary model, not just WC
+    def simulate(self, params):
         args_local = self.args.copy()
-        args_local.update(dict(zip(self.control_params, [c for c in control])))
-        return timeIntegration_elementwise(**args_local)
-
-    def get_output(self, control):
-        t, exc, inh, exc_ou, inh_ou = self.simulate(control)
-        if self.target_params == ["exc", "inh"]:
-            output = jnp.stack((exc, inh), axis=0)
-        elif self.target_params == ["exc"]:
-            output = exc[None, ...]
-        elif self.target_params == ["inh"]:
-            output = inh[None, ...]
-        return output[:, :, self.startind :]
+        args_local.update(params)
+        t, exc, inh, exc_ou, inh_ou = timeIntegration_elementwise(**args_local)
+        return {
+            "t": t,
+            "exc": exc,
+            "inh": inh,
+            "exc_ou": exc_ou,
+            "inh_ou": inh_ou,
+        }
+
+    def get_output(self, params):
+        simulation_results = self.simulate(params)
+        return jnp.stack([simulation_results[tp][:, self.startind :] for tp in self.target_param_names])

     def get_loss(self):
         @jax.jit
-        def loss(control):
-            output = self.get_output(control)
-            return self.compute_total_cost(control, output)
+        def loss(params):
+            output = self.get_output(params)
+            return self.loss_function(params, output)
 
         return loss

+    def optimize_deterministic(self, n_max_iterations, output_every_nth=None):
+        loss = self.compute_loss(self.control)
+        print(f"loss in iteration 0: %s" % (loss))
+        if len(self.cost_history) == 0:  # add only if control model has not yet been optimized
+            self.cost_history.append(loss)
+
+        for i in range(1, n_max_iterations + 1):
+            self.gradient = self.compute_gradient(self.control)
+            updates, self.opt_state = self.optimizer.update(self.gradient, self.opt_state)
+            self.control = optax.apply_updates(self.control, updates)
+
+            if output_every_nth is not None and i % output_every_nth == 0:
+                loss = self.compute_loss(self.control)
+                self.cost_history.append(loss)
+                print(f"loss in iteration %s: %s" % (i, loss))
+
+        loss = self.compute_loss(self.control)
+        print(f"Final loss : %s" % (loss))


+class OcWc(Optimize):
+    def __init__(
+        self,
+        model,
+        target=None,
+        optimizer=optax.adabelief(1e-3),
+        control_param_names=wc_default_control_params,
+        target_param_names=wc_default_target_params,
+    ):
+        super().__init__(
+            model,
+            self.compute_total_cost,
+            control_param_names,
+            target_param_names,
+            target=target,
+            init_params=None,
+            optimizer=optimizer,
+        )
+        self.control = self.params
+        self.weights = getdefaultweights()

     def compute_total_cost(self, control, output):
         """
         Compute the total cost as the sum of accuracy cost and control strength cost.
         Parameters:
-            control (jax.numpy.ndarray): Control input array of shape ((len(control_params)), N, T).
-            output (jax.numpy.ndarray): Simulation output of shape ((len(target_params)), N, T).
+            control (dict[str, jax.numpy.ndarray]): Dictionary of control inputs, where each entry has shape (N, T).
+            output (jax.numpy.ndarray): Simulation output of shape ((len(target_param_names)), N, T).
         Returns:
             float: The total cost.
         """
         accuracy_cost = self.accuracy_cost(output)
-        control_strength_cost = self.control_strength_cost(control)
+        control_arr = jnp.array(list(control.values()))
+        control_strength_cost = self.control_strength_cost(control_arr)
         return accuracy_cost + control_strength_cost
 
+    # TODO: move cost functions outside
     def accuracy_cost(self, output):
         accuracy_cost = 0.0
         if self.weights["w_p"] != 0.0:
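Note on the refactor: this commit splits the former OcWc class into a generic Optimize base class, which runs gradient descent on a dict of named model parameters via jax.grad and optax, and a thin OcWc subclass that plugs in the optimal-control cost. A minimal usage sketch of the new interface, inferred only from the signatures visible in this diff (the model setup, duration, and stand-in target are illustrative assumptions; accuracy_cost is truncated above, so the exact use of the target is not shown here):

import optax
from neurolib.models.jax.wc import WCModel
from neurolib.control.optimal_control.oc_jax import OcWc

model = WCModel()                 # 1-node Wilson-Cowan network by default
model.params.duration = 500.0     # illustrative duration in ms

oc = OcWc(model, optimizer=optax.adabelief(1e-3))
# Use the output under the initial inputs as a stand-in target; a real target
# should match get_output's shape, (len(target_param_names), N, T - startind).
oc.target = oc.get_output(oc.params)

oc.optimize_deterministic(n_max_iterations=100, output_every_nth=10)
control = oc.control              # dict mapping "exc_ext"/"inh_ext" to (N, T) arrays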

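Because the loss function is now injected, the base class can optimize objectives other than the standard control cost. A sketch of a custom loss under the loss_function(params, output) contract visible above; the objective itself is invented for illustration:

import jax.numpy as jnp
import optax
from neurolib.models.jax.wc import WCModel
from neurolib.control.optimal_control.oc_jax import Optimize

model = WCModel()

def energy_loss(params, output):
    # params: dict of named inputs; output: (len(target_param_names), N, T).
    # Illustrative objective: raise excitatory activity, penalize input energy.
    activity = jnp.mean(output[0])
    energy = sum(jnp.sum(p ** 2) for p in params.values())
    return -activity + 1e-4 * energy

opt = Optimize(
    model,
    energy_loss,
    param_names=["exc_ext"],         # must appear in model.args_names
    target_param_names=["exc"],      # must appear in model.output_vars
    optimizer=optax.adam(1e-3),
)
gradient = opt.compute_gradient(opt.params)  # dict with the same keys as opt.params

Note that optimize_deterministic as committed reads self.control, which only the OcWc subclass defines (self.control = self.params), so a bare Optimize instance is driven through compute_loss and compute_gradient directly, as above.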

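Note on compute_total_cost: the control argument is now a dict of (N, T) arrays rather than a single stacked array, so the method rebuilds the old (len(control_params), N, T) layout with jnp.array(list(control.values())) before calling control_strength_cost. Since Python dicts preserve insertion order, the stacking order follows the order of the parameter names. A standalone illustration with assumed shapes:

import jax.numpy as jnp

control = {
    "exc_ext": jnp.ones((1, 100)),    # (N, T) with N=1, T=100 (assumed)
    "inh_ext": jnp.zeros((1, 100)),
}
control_arr = jnp.array(list(control.values()))
print(control_arr.shape)              # (2, 1, 100): (n_controls, N, T)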