Commit 5ae2b5c
cleanup
1b15 committed Jan 9, 2025
1 parent 6c4b3b2 commit 5ae2b5c
Showing 2 changed files with 173 additions and 108 deletions.
169 changes: 61 additions & 108 deletions neurolib/control/optimal_control/oc_jax.py
@@ -5,36 +5,15 @@
 import copy
 from neurolib.models.jax.wc import WCModel
 from neurolib.models.jax.wc.timeIntegration import timeIntegration_args, timeIntegration_elementwise

 import logging
+from neurolib.optimize.loss_functions import (
+    kuramoto_loss,
+    cross_correlation_loss,
+    variance_loss,
+    osc_fourier_loss,
+    sync_fourier_loss,
+)
 from neurolib.control.optimal_control.oc import getdefaultweights

-# TODO: introduce for all models, not just WC
-wc_default_control_params = ["exc_ext", "inh_ext"]
-wc_default_target_params = ["exc", "inh"]
-
-
-def hilbert_jax(signal, axis=-1):
-
-    n = signal.shape[axis]
-    h = jnp.zeros(n)
-    h = h.at[0].set(1)
-
-    if n % 2 == 0:
-        h = h.at[1 : n // 2].set(2)
-        h = h.at[n // 2].set(1)
-    else:
-        h = h.at[1 : (n + 1) // 2].set(2)
-
-    h = jnp.expand_dims(h, tuple(i for i in range(signal.ndim) if i != axis))
-    h = jnp.broadcast_to(h, signal.shape)
-
-    fft_signal = jnp.fft.fft(signal, axis=axis)
-    analytic_fft = fft_signal * h
-
-    analytic_signal = jnp.fft.ifft(analytic_fft)
-    return analytic_signal
 class Optimize:
     def __init__(
@@ -43,8 +22,8 @@ def __init__(
         loss_function,
         param_names,
         target_param_names,
-        target=None,
         init_params=None,
+        regularization_function=lambda _: 0.0,
         optimizer=optax.adabelief(1e-3),
     ):
         assert isinstance(param_names, (list, tuple)) and len(param_names) > 0
@@ -54,7 +33,7 @@ def __init__(

         self.model = copy.deepcopy(model)
         self.loss_function = loss_function
-        self.target = target
+        self.regularization_function = regularization_function
         self.optimizer = optimizer
         self.param_names = param_names
         self.target_param_names = target_param_names
@@ -70,7 +49,7 @@ def __init__(
         self.params = dict(zip(param_names, [self.args[p] for p in param_names]))
         self.opt_state = self.optimizer.init(self.params)

-        compute_loss = lambda params: self.loss_function(params, self.get_output(params))
+        compute_loss = lambda params: self.loss_function(self.get_output(params)) + self.regularization_function(params)
         self.compute_loss = jax.jit(compute_loss)
         self.compute_gradient = jax.jit(jax.grad(self.compute_loss))

@@ -93,15 +72,7 @@ def get_output(self, params):
         simulation_results = self.simulate(params)
         return jnp.stack([simulation_results[tp][:, self.startind :] for tp in self.target_param_names])

-    def get_loss(self):
-        @jax.jit
-        def loss(params):
-            output = self.get_output(params)
-            return self.loss_function(params, output)
-
-        return loss
-
-    def optimize_deterministic(self, n_max_iterations, output_every_nth=None):
+    def optimize(self, n_max_iterations, output_every_nth=None):
         loss = self.compute_loss(self.control)
         print(f"loss in iteration 0: %s" % (loss))
         if len(self.cost_history) == 0:  # add only if control model has not yet been optimized
@@ -121,105 +92,87 @@ def optimize_deterministic(self, n_max_iterations, output_every_nth=None):
         print(f"Final loss : %s" % (loss))


-class OcWc(Optimize):
+class Oc(Optimize):
     """
     Convenience class for optimal control. The cost functional is constructed as a weighted sum of accuracy and control strength costs. Requires optimization parameters to be of shape (N, T).
     """

     supported_cost_parameters = [
         "w_p",
         "w_cc",
         "w_var",
         "w_f_osc",
         "w_f_sync",
         "w_ko",
         "w_2",
         "w_1D",
     ]

     def __init__(
         self,
         model,
-        target=None,
+        target_timeseries=None,
+        target_frequency=None,
         optimizer=optax.adabelief(1e-3),
-        control_param_names=wc_default_control_params,
-        target_param_names=wc_default_target_params,
+        control_param_names=["exc_ext", "inh_ext"],
+        target_param_names=["exc", "inh"],
+        weights=None,
     ):
         super().__init__(
             model,
-            self.compute_total_cost,
+            self.accuracy_cost,
             control_param_names,
             target_param_names,
-            target=target,
             init_params=None,
             optimizer=optimizer,
+            regularization_function=self.control_strength_cost,
         )
+        self.target_timeseries = target_timeseries
+        self.target_frequency = target_frequency
         self.control = self.params
-        self.weights = getdefaultweights()
+        if weights is None:
+            self.weights = getdefaultweights()

-    def compute_total_cost(self, control, output):
+    def accuracy_cost(self, output):
         """
-        Compute the total cost as the sum of accuracy cost and control strength cost.
-        Parameters:
-            control (dict[str, jax.numpy.ndarray]): Dictionary of control inputs, where each entry has shape (N, T).
-            output (jax.numpy.ndarray): Simulation output of shape ((len(target_param_names)), N, T).
-        Returns:
-            float: The total cost.
+        Args:
+            output (jax.numpy.ndarray): Simulation output of shape ((len(target_param_names)), N, T).
         """
-        accuracy_cost = self.accuracy_cost(output)
-        control_arr = jnp.array(list(control.values()))
-        control_strength_cost = self.control_strength_cost(control_arr)
-        return accuracy_cost + control_strength_cost
-
-    # TODO: move cost functions outside
-    def accuracy_cost(self, output):
         accuracy_cost = 0.0
         if self.weights["w_p"] != 0.0:
-            accuracy_cost += self.weights["w_p"] * 0.5 * self.model.params.dt * jnp.sum((output - self.target) ** 2)
+            accuracy_cost += self.weights["w_p"] * self.precision_cost(output)
         if self.weights["w_cc"] != 0.0:
-            accuracy_cost += self.weights["w_cc"] * self.compute_cc_cost(output)
+            accuracy_cost += self.weights["w_cc"] * cross_correlation_loss(output, self.model.params.dt)
         if self.weights["w_var"] != 0.0:
-            accuracy_cost += self.weights["w_var"] * self.compute_var_cost(output)
+            accuracy_cost += self.weights["w_var"] * variance_loss(output)
         if self.weights["w_f_osc"] != 0.0:
-            accuracy_cost += self.weights["w_f_osc"] * self.compute_osc_fourier_cost(output)
+            accuracy_cost += self.weights["w_f_osc"] * osc_fourier_loss(
+                output, self.target_frequency, self.model.params.dt
+            )
         if self.weights["w_f_sync"] != 0.0:
-            accuracy_cost += self.weights["w_f_sync"] * self.compute_sync_fourier_cost(output)
+            accuracy_cost += self.weights["w_f_sync"] * sync_fourier_loss(
+                output, self.target_frequency, self.model.params.dt
+            )
         if self.weights["w_ko"] != 0.0:
-            accuracy_cost += self.weights["w_ko"] * self.compute_kuramoto_cost(output)
+            accuracy_cost += self.weights["w_ko"] * kuramoto_loss(output)
         return accuracy_cost

+    def precision_cost(self, output):
+        return 0.5 * self.model.params.dt * jnp.sum((output - self.target_timeseries) ** 2)
+
     def control_strength_cost(self, control):
+        """
+        Args:
+            control (dict[str, jax.numpy.ndarray]): Dictionary of control inputs, where each entry has shape (N, T).
+        """
+        control_arr = jnp.array(list(control.values()))
         control_strength_cost = 0.0
         if self.weights["w_2"] != 0.0:
-            control_strength_cost += self.weights["w_2"] * 0.5 * self.model.params.dt * jnp.sum(control**2)
+            control_strength_cost += self.weights["w_2"] * 0.5 * self.model.params.dt * jnp.sum(control_arr**2)
         if self.weights["w_1D"] != 0.0:
-            control_strength_cost += self.weights["w_1D"] * self.compute_ds_cost(control)
+            control_strength_cost += self.weights["w_1D"] * self.compute_ds_cost(control_arr)
         return control_strength_cost

     def compute_ds_cost(self, control):
         eps = 1e-6  # avoid grad(sqrt(0.0))
         return jnp.sum(jnp.sqrt(jnp.sum(control**2, axis=2) * self.model.params.dt + eps))
-
-    def compute_cc_cost(self, output):
-        xmean = jnp.mean(output, axis=2, keepdims=True)
-        xstd = jnp.std(output, axis=2, keepdims=True)
-
-        xvec = (output - xmean) / xstd
-
-        costmat = jnp.einsum("vnt,vkt->vnkt", xvec, xvec)
-        diag = jnp.einsum("vnt,vnt->vt", xvec, xvec)
-        cost = jnp.sum(jnp.sum(costmat, axis=(1, 2)) - diag) * self.model.params.dt / 2.0
-        cost *= -2.0 / (self.model.params.N * (self.model.params.N - 1) * self.T * self.model.params.dt)
-        return cost
-
-    def compute_var_cost(self, output):
-        return jnp.var(output, axis=(0, 1)).mean()
-
-    def get_fourier_component(self, data, target_period):
-        fourier_series = jnp.abs(jnp.fft.fft(data)[: len(data) // 2])
-        freqs = jnp.fft.fftfreq(data.size, d=self.model.params.dt)[: len(data) // 2]
-        return fourier_series[jnp.argmin(jnp.abs(freqs - 1.0 / target_period))]
-
-    def compute_osc_fourier_cost(self, output):
-        cost = 0.0
-        for n in range(output.shape[1]):
-            for v in range(output.shape[0]):
-                cost -= self.get_fourier_component(output[v, n], self.target) ** 2
-        return cost / (output.shape[2] * self.model.params.dt) ** 2
-
-    def compute_sync_fourier_cost(self, output):
-        cost = 0.0
-        for v in range(output.shape[0]):
-            cost -= self.get_fourier_component(jnp.sum(output[v], axis=0), self.target) ** 2
-        return cost / (output.shape[2] * self.model.params.dt) ** 2
-
-    def compute_kuramoto_cost(self, output):
-        phase = jnp.angle(hilbert_jax(output, axis=2))
-        return -jnp.mean(jnp.abs(jnp.mean(jnp.exp(complex(0, 1) * phase), axis=1)))
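For orientation, a minimal usage sketch of the refactored `Oc` interface follows. It is not part of the commit: it assumes the JAX `WCModel` exposes the standard neurolib `run()` and output-attribute interface, and every numeric value is illustrative.

```python
import numpy as np
from neurolib.models.jax.wc import WCModel
from neurolib.control.optimal_control.oc_jax import Oc

# Run the uncontrolled model once to obtain a hypothetical target trajectory.
model = WCModel()
model.params.duration = 300.0  # ms, illustrative value
model.run()
target = np.stack([model.exc, model.inh])  # (2, N, T), matching target_param_names

oc = Oc(model, target_timeseries=target)
oc.weights["w_p"] = 1.0   # precision (tracking) weight
oc.weights["w_2"] = 1e-4  # L2 energy weight on the control input

oc.optimize(500)  # 500 gradient steps with the default optax.adabelief optimizer
optimal_control = oc.control  # dict {"exc_ext": ..., "inh_ext": ...}, each (N, T)
```

Splitting the loss into an `accuracy_cost` on the output plus a `regularization_function` on the parameters is what lets `Optimize` drop the old two-argument `loss_function(params, output)` signature while keeping both terms inside the jitted `compute_loss`.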
112 changes: 112 additions & 0 deletions neurolib/optimize/loss_functions.py
@@ -0,0 +1,112 @@
import jax
import jax.numpy as jnp


def variance_loss(output):
    """
    Args:
        output (jax.numpy.ndarray): Time series data with shape (n_output_vars, N, T)
            where N is number of nodes and T is number of timepoints
    Returns:
        float: Variance across output variables and nodes, averaged over time
    """
    return jnp.var(output, axis=(0, 1)).mean()


def cross_correlation_loss(output, dt=1.0):
    """
    Args:
        output (jax.numpy.ndarray): Time series data with shape (n_output_vars, N, T)
            where N is number of nodes and T is number of timepoints
        dt (float): Time step
    Returns:
        float: Negative mean pairwise correlation between nodes; equals -1 when
            all nodes are perfectly correlated
    """
    _, N, T = output.shape
    xmean = jnp.mean(output, axis=2, keepdims=True)
    xstd = jnp.std(output, axis=2, keepdims=True)

    # Z-score each node's time series.
    xvec = (output - xmean) / xstd

    # Pairwise products of the normalized signals; the diagonal (self-correlation)
    # is subtracted so only distinct node pairs contribute.
    lossmat = jnp.einsum("vnt,vkt->vnkt", xvec, xvec)
    diag = jnp.einsum("vnt,vnt->vt", xvec, xvec)
    loss = jnp.sum(jnp.sum(lossmat, axis=(1, 2)) - diag) * dt / 2.0
    loss *= -2.0 / (N * (N - 1) * T * dt)
    return loss


def hilbert(signal, axis=-1):
    """JAX analytic-signal transform, analogous to scipy.signal.hilbert."""
    axis = axis % signal.ndim  # normalize negative axis indices for the reshapes below
    n = signal.shape[axis]

    # Filter that keeps DC, doubles positive frequencies, and zeroes negative ones.
    h = jnp.zeros(n)
    h = h.at[0].set(1)

    if n % 2 == 0:
        h = h.at[1 : n // 2].set(2)
        h = h.at[n // 2].set(1)  # the Nyquist bin is its own conjugate, keep it once
    else:
        h = h.at[1 : (n + 1) // 2].set(2)

    # Broadcast the 1D filter along the transform axis of the input.
    h = jnp.expand_dims(h, tuple(i for i in range(signal.ndim) if i != axis))
    h = jnp.broadcast_to(h, signal.shape)

    fft_signal = jnp.fft.fft(signal, axis=axis)
    analytic_fft = fft_signal * h

    analytic_signal = jnp.fft.ifft(analytic_fft, axis=axis)
    return analytic_signal


def kuramoto_loss(output):
    """
    Args:
        output (jax.numpy.ndarray): Time series data with shape (n_output_vars, N, T)
            where N is number of nodes and T is number of timepoints
    Returns:
        float: Negative Kuramoto order parameter, averaged over output variables and time
    """
    # Instantaneous phase of each node from the analytic signal.
    phase = jnp.angle(hilbert(output, axis=2))
    # Order parameter R(t) = |mean_n exp(i * phase_n(t))|; the modulus (not the real
    # part) measures phase coherence independent of the mean phase.
    return -jnp.mean(jnp.abs(jnp.mean(jnp.exp(1j * phase), axis=1)))


def get_fourier_component(data, target_frequency, dt=1.0):
    """Magnitude of the FFT bin of 1D `data` closest to `target_frequency`,
    with frequencies resolved using time step `dt`."""
    fourier_series = jnp.abs(jnp.fft.fft(data)[: len(data) // 2])
    freqs = jnp.fft.fftfreq(data.size, d=dt)[: len(data) // 2]
    return fourier_series[jnp.argmin(jnp.abs(freqs - target_frequency))]


def osc_fourier_loss(output, target_frequency, dt=1.0):
    """
    Args:
        output (jax.numpy.ndarray): Time series data with shape (n_output_vars, N, T)
            where N is number of nodes and T is number of timepoints
        target_frequency (float): Frequency to optimize for
        dt (float): Time step
    Returns:
        float: Negative oscillatory power of each individual node at the target
            frequency, irrespective of the phase relations between nodes
    """
    loss = 0.0
    for n in range(output.shape[1]):
        for v in range(output.shape[0]):
            # Pass dt so the target frequency is resolved in physical units.
            loss -= get_fourier_component(output[v, n], target_frequency, dt) ** 2
    return loss / (output.shape[2] * dt) ** 2


def sync_fourier_loss(output, target_frequency, dt=1.0):
    """
    Args:
        output (jax.numpy.ndarray): Time series data with shape (n_output_vars, N, T)
            where N is number of nodes and T is number of timepoints
        target_frequency (float): Frequency to optimize for
        dt (float): Time step
    Returns:
        float: Negative power of the summed node signals at the target frequency;
            nodes must align in phase to add constructively
    """
    loss = 0.0
    for v in range(output.shape[0]):
        # Pass dt so the target frequency is resolved in physical units.
        loss -= get_fourier_component(jnp.sum(output[v], axis=0), target_frequency, dt) ** 2
    return loss / (output.shape[2] * dt) ** 2
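A quick sanity check of the new loss module, again not part of the commit: synthetic sinusoidal data with known phase structure, where the expected values follow directly from the definitions above.

```python
import numpy as np
import jax.numpy as jnp
from neurolib.optimize.loss_functions import (
    kuramoto_loss,
    cross_correlation_loss,
    osc_fourier_loss,
)

dt = 0.1                   # time step in ms
t = np.arange(0, 500, dt)  # T = 5000 timepoints
f = 0.025                  # 0.025 cycles/ms = 25 Hz

# Output arrays of shape (n_output_vars, N, T): one variable, eight nodes.
coherent = jnp.asarray(np.tile(np.sin(2 * np.pi * f * t), (1, 8, 1)))
rng = np.random.default_rng(0)
offsets = rng.uniform(0, 2 * np.pi, size=(1, 8, 1))
scattered = jnp.asarray(np.sin(2 * np.pi * f * t[None, None, :] + offsets))

print(kuramoto_loss(coherent))               # ~ -1: all nodes share one phase
print(kuramoto_loss(scattered))              # nearer 0: low phase coherence
print(cross_correlation_loss(coherent, dt))  # ~ -1: perfectly correlated nodes
print(osc_fourier_loss(coherent, f, dt))     # < 0: strong power at 25 Hz
```

`kuramoto_loss` is bounded in [-1, 0] because the order parameter lies in [0, 1], whereas the Fourier losses scale with the signal power at the target frequency.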
