diff --git a/neuro_py/ensemble/decoding/lstm.py b/neuro_py/ensemble/decoding/lstm.py index 05ce188..c0753f4 100644 --- a/neuro_py/ensemble/decoding/lstm.py +++ b/neuro_py/ensemble/decoding/lstm.py @@ -1,3 +1,5 @@ +from typing import List, Dict, Tuple, Optional + import torch import torch.nn.functional as F import lightning as L @@ -6,43 +8,60 @@ class LSTM(L.LightningModule): - """Long Short-Term Memory (LSTM) model.""" - def __init__(self, in_dim=100, out_dim=2, hidden_dims=(400, 1, .0), use_bias=True, args={}): - """ - Constructs a LSTM model - - Parameters - ---------- - in_dim : int - Dimensionality of input data - out_dim : int - Dimensionality of output data - hidden_dims : List - Architectural parameters of the model - (hidden_size, num_layers, dropout) - use_bias : bool - Whether to use bias or not in the final linear layer - """ + """ + Long Short-Term Memory (LSTM) model. + + This class implements an LSTM model using PyTorch Lightning. + + Parameters + ---------- + in_dim : int, optional + Dimensionality of input data, by default 100 + out_dim : int, optional + Dimensionality of output data, by default 2 + hidden_dims : Tuple[int, int, float], optional + Architectural parameters of the model (hidden_size, num_layers, dropout), + by default (400, 1, 0.0) + use_bias : bool, optional + Whether to use bias or not in the final linear layer, by default True + args : Dict, optional + Additional arguments for model configuration, by default {} + + Attributes + ---------- + lstm : nn.LSTM + LSTM layer + fc : nn.Linear + Fully connected layer + hidden_state : Optional[torch.Tensor] + Hidden state of the LSTM + cell_state : Optional[torch.Tensor] + Cell state of the LSTM + """ + def __init__(self, in_dim: int = 100, out_dim: int = 2, + hidden_dims: Tuple[int, int, float] = (400, 1, 0.0), + use_bias: bool = True, args: Dict = {}): super().__init__() self.save_hyperparameters() self.in_dim = in_dim self.out_dim = out_dim if len(hidden_dims) != 3: raise ValueError('`hidden_dims` should be of size 3') - hidden_size, nlayers, dropout = hidden_dims - self.nlayers = nlayers - self.hidden_size = hidden_size - self.dropout = dropout + self.hidden_size, self.nlayers, self.dropout = hidden_dims self.args = args - # Add final layer to the number of classes - self.lstm = nn.LSTM(input_size=in_dim, hidden_size=hidden_size, - num_layers=nlayers, batch_first=True, dropout=dropout, bidirectional=True) - self.fc = nn.Linear(in_features=2*hidden_size, out_features=out_dim, bias=use_bias) - self.hidden_state = None - self.cell_state = None + self.lstm = nn.LSTM(input_size=in_dim, hidden_size=self.hidden_size, + num_layers=self.nlayers, batch_first=True, + dropout=self.dropout, bidirectional=True) + self.fc = nn.Linear(in_features=2*self.hidden_size, out_features=out_dim, bias=use_bias) + self.hidden_state: Optional[torch.Tensor] = None + self.cell_state: Optional[torch.Tensor] = None - def init_params(m): + self._init_params() + + def _init_params(self) -> None: + """Initialize model parameters.""" + def init_params(m: nn.Module) -> None: if isinstance(m, nn.Linear): torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu') if m.bias is not None: @@ -51,35 +70,63 @@ def init_params(m): nn.init.uniform_(m.bias, -bound, bound) # LeCunn init init_params(self.fc) - def forward(self, x): - lstm_out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state)) - B, L, H = lstm_out.shape - # Shape: [batch_size x max_length x hidden_dim] + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + """ + Forward pass of the LSTM model. - # Select the activation of the last Hidden Layer - # lstm_out = lstm_out.view(B, L, 2, -1).sum(dim=2) - lstm_out = lstm_out[:,-1,:].contiguous() - - # Shape: [batch_size x hidden_dim] + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, sequence_length, input_dim) - # Fully connected layer + Returns + ------- + torch.Tensor + Output tensor of shape (batch_size, output_dim) + """ + lstm_out, (self.hidden_state, self.cell_state) = \ + self.lstm(x, (self.hidden_state, self.cell_state)) + lstm_out = lstm_out[:, -1, :].contiguous() out = self.fc(lstm_out) - if self.args['clf']: + if self.args.get('clf', False): out = F.log_softmax(out, dim=1) - return out - def init_hidden(self, batch_size): - ''' Initializes hidden state ''' - # Create two new tensors with sizes n_layers x batch_size x hidden_dim, - # initialized to zero, for hidden state and cell state of LSTM + def init_hidden(self, batch_size: int) -> None: + """ + Initialize hidden state and cell state. + + Parameters + ---------- + batch_size : int + Batch size for initialization + """ self.batch_size = batch_size - h0 = torch.zeros((2*self.nlayers,batch_size,self.hidden_size), requires_grad=False) - c0 = torch.zeros((2*self.nlayers,batch_size,self.hidden_size), requires_grad=False) + h0 = torch.zeros( + (2*self.nlayers, batch_size, self.hidden_size), + requires_grad=False + ) + c0 = torch.zeros( + (2*self.nlayers, batch_size, self.hidden_size), + requires_grad=False + ) self.hidden_state = h0 self.cell_state = c0 - def predict(self, x): + def predict(self, x: torch.Tensor) -> torch.Tensor: + """ + Make predictions using the LSTM model. + + Parameters + ---------- + x : torch.Tensor + Input tensor + + Returns + ------- + torch.Tensor + Predicted output + """ self.hidden_state = self.hidden_state.to(x.device) self.cell_state = self.cell_state.to(x.device) preds = [] @@ -93,45 +140,110 @@ def predict(self, x): pred_loc = pred_loc[:batch_size-(i-x.shape[0])] preds.extend(pred_loc) out = torch.stack(preds) - if self.args['clf']: + if self.args.get('clf', False): out = F.log_softmax(out, dim=1) return out - def _step(self, batch, batch_idx) -> torch.Tensor: - xs, ys = batch # unpack the batch - outs = self(xs) # apply the model + def _step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Perform a single step (forward pass + loss calculation). + + Parameters + ---------- + batch : Tuple[torch.Tensor, torch.Tensor] + Batch of input data and labels + batch_idx : int + Index of the current batch + + Returns + ------- + torch.Tensor + Computed loss + """ + xs, ys = batch + outs = self(xs) loss = self.args['criterion'](outs, ys) return loss - def training_step(self, batch, batch_idx) -> torch.Tensor: + def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Lightning method for training step. 
+ + Parameters + ---------- + batch : Tuple[torch.Tensor, torch.Tensor] + Batch of input data and labels + batch_idx : int + Index of the current batch + + Returns + ------- + torch.Tensor + Computed loss + """ loss = self._step(batch, batch_idx) self.log('train_loss', loss) return loss - def on_after_backward(self): - # LSTM specific + def on_after_backward(self) -> None: + """Lightning method called after backpropagation.""" self.hidden_state.detach_() self.cell_state.detach_() - # self.hidden_state.data.fill_(.0) - # self.cell_state.data.fill_(.0) - def validation_step(self, batch, batch_idx) -> torch.Tensor: + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Lightning method for validation step. + + Parameters + ---------- + batch : Tuple[torch.Tensor, torch.Tensor] + Batch of input data and labels + batch_idx : int + Index of the current batch + + Returns + ------- + torch.Tensor + Computed loss + """ loss = self._step(batch, batch_idx) self.log('val_loss', loss) return loss - def test_step(self, batch, batch_idx) -> torch.Tensor: + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Lightning method for test step. + + Parameters + ---------- + batch : Tuple[torch.Tensor, torch.Tensor] + Batch of input data and labels + batch_idx : int + Index of the current batch + + Returns + ------- + torch.Tensor + Computed loss + """ loss = self._step(batch, batch_idx) self.log('test_loss', loss) return loss - def configure_optimizers(self): - args = self.args + def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[Dict]]: + """ + Configure optimizers and learning rate schedulers. + + Returns + ------- + Tuple[List[torch.optim.Optimizer], List[Dict]] + Tuple containing a list of optimizers and a list of scheduler configurations + """ optimizer = torch.optim.AdamW( - self.parameters(), weight_decay=args['weight_decay']) + self.parameters(), weight_decay=self.args['weight_decay']) scheduler = torch.optim.lr_scheduler.OneCycleLR( - optimizer, max_lr=args['lr'], - epochs=args['epochs'], + optimizer, max_lr=self.args['lr'], + epochs=self.args['epochs'], steps_per_epoch=len( self.trainer._data_connector._train_dataloader_source.dataloader() ) diff --git a/neuro_py/ensemble/decoding/m2mlstm.py b/neuro_py/ensemble/decoding/m2mlstm.py index 9113c87..8c4aae0 100644 --- a/neuro_py/ensemble/decoding/m2mlstm.py +++ b/neuro_py/ensemble/decoding/m2mlstm.py @@ -1,3 +1,5 @@ +from typing import List, Tuple, Dict, Optional + import numpy as np import torch import torch.nn.functional as F @@ -7,43 +9,60 @@ class M2MLSTM(L.LightningModule): - """Many-to-Many Long Short-Term Memory (LSTM) model.""" - def __init__(self, in_dim=100, out_dim=2, hidden_dims=(400, 1, .0), use_bias=True, args={}): - """ - Constructs a Many-to-Many LSTM + """ + Many-to-Many Long Short-Term Memory (LSTM) model. - Parameters - ---------- - in_dim : int - Dimensionality of input data - out_dim : int - Number of output columns - hidden_dims : List - Architectural parameters of the model - (hidden_size, num_layers, dropout) - use_bias : bool - Whether to use bias or not in the final linear layer - """ + This class implements a Many-to-Many LSTM model using PyTorch Lightning. 
+ + Parameters + ---------- + in_dim : int, optional + Dimensionality of input data, by default 100 + out_dim : int, optional + Number of output columns, by default 2 + hidden_dims : Tuple[int, int, float], optional + Architectural parameters of the model (hidden_size, num_layers, dropout), + by default (400, 1, 0.0) + use_bias : bool, optional + Whether to use bias or not in the final linear layer, by default True + args : Dict, optional + Additional arguments for model configuration, by default {} + + Attributes + ---------- + lstm : nn.LSTM + LSTM layer + fc : nn.Linear + Fully connected layer + hidden_state : Optional[torch.Tensor] + Hidden state of the LSTM + cell_state : Optional[torch.Tensor] + Cell state of the LSTM + """ + def __init__(self, in_dim: int = 100, out_dim: int = 2, + hidden_dims: Tuple[int, int, float] = (400, 1, 0.0), + use_bias: bool = True, args: Dict = {}): super().__init__() self.save_hyperparameters() self.in_dim = in_dim self.out_dim = out_dim if len(hidden_dims) != 3: raise ValueError('`hidden_dims` should be of size 3') - hidden_size, nlayers, dropout = hidden_dims - self.nlayers = nlayers - self.hidden_size = hidden_size - self.dropout = dropout + self.hidden_size, self.nlayers, self.dropout = hidden_dims self.args = args - # Add final layer to the number of classes - self.lstm = nn.LSTM(input_size=in_dim, hidden_size=hidden_size, - num_layers=nlayers, batch_first=True, dropout=dropout, bidirectional=False) - self.fc = nn.Linear(in_features=hidden_size, out_features=out_dim, bias=use_bias) - self.hidden_state = None - self.cell_state = None + self.lstm = nn.LSTM(input_size=in_dim, hidden_size=self.hidden_size, + num_layers=self.nlayers, batch_first=True, + dropout=self.dropout, bidirectional=False) + self.fc = nn.Linear(in_features=self.hidden_size, out_features=out_dim, bias=use_bias) + self.hidden_state: Optional[torch.Tensor] = None + self.cell_state: Optional[torch.Tensor] = None + + self._init_params() - def init_params(m): + def _init_params(self) -> None: + """Initialize model parameters.""" + def init_params(m: nn.Module) -> None: if isinstance(m, nn.Linear): torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu') if m.bias is not None: @@ -52,78 +71,152 @@ def init_params(m): nn.init.uniform_(m.bias, -bound, bound) # LeCunn init init_params(self.fc) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the LSTM model. 
+ + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, sequence_length, input_dim) + + Returns + ------- + torch.Tensor + Output tensor of shape (batch_size, sequence_length, output_dim) + """ B, L, N = x.shape self.hidden_state = self.hidden_state.to(x.device) self.cell_state = self.cell_state.to(x.device) - self.hidden_state.data.fill_(.0) - self.cell_state.data.fill_(.0) + self.hidden_state.data.fill_(0.0) + self.cell_state.data.fill_(0.0) lstm_outs = [] for i in range(L): lstm_out, (self.hidden_state, self.cell_state) = \ self.lstm(x[:, i].unsqueeze(1), (self.hidden_state, self.cell_state)) - # Shape: [batch_size x max_length x hidden_dim] lstm_outs.append(lstm_out) lstm_outs = torch.stack(lstm_outs, dim=1) # B, L, N - # Select the activation of the last Hidden Layer - # lstm_outs = lstm_outs.contiguous() - # lstm_outs = lstm_outs.view(-1, lstm_outs.shape[2]) # B*L, N - - # Shape: [batch_size x hidden_dim] - - # Fully connected layer out = self.fc(lstm_outs) out = out.view(B, L, self.out_dim) - if self.args['clf']: + if self.args.get('clf', False): out = F.log_softmax(out, dim=-1) return out - def init_hidden(self, batch_size): - ''' Initializes hidden state ''' - # Create two new tensors with sizes n_layers x batch_size x hidden_dim, - # initialized to zero, for hidden state and cell state of LSTM + def init_hidden(self, batch_size: int) -> None: + """ + Initialize hidden state and cell state. + + Parameters + ---------- + batch_size : int + Batch size for initialization + """ self.batch_size = batch_size - h0 = torch.zeros((self.nlayers,batch_size,self.hidden_size), requires_grad=False) - c0 = torch.zeros((self.nlayers,batch_size,self.hidden_size), requires_grad=False) - self.hidden_state = h0 - self.cell_state = c0 - - def _step(self, batch, batch_idx) -> torch.Tensor: - xs, ys = batch # unpack the batch - B, L, N = xs.shape + self.hidden_state = torch.zeros((self.nlayers, batch_size, self.hidden_size), requires_grad=False) + self.cell_state = torch.zeros((self.nlayers, batch_size, self.hidden_size), requires_grad=False) + + def _step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Perform a single step (forward pass + loss calculation). + + Parameters + ---------- + batch : Tuple[torch.Tensor, torch.Tensor] + Batch of input data and labels + batch_idx : int + Index of the current batch + + Returns + ------- + torch.Tensor + Computed loss + """ + xs, ys = batch outs = self(xs) loss = self.args['criterion'](outs, ys) return loss - def training_step(self, batch, batch_idx) -> torch.Tensor: + def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Lightning method for training step. + + Parameters + ---------- + batch : Tuple[torch.Tensor, torch.Tensor] + Batch of input data and labels + batch_idx : int + Index of the current batch + + Returns + ------- + torch.Tensor + Computed loss + """ loss = self._step(batch, batch_idx) self.log('train_loss', loss) return loss - def on_after_backward(self): - # LSTM specific + def on_after_backward(self) -> None: + """Lightning method called after backpropagation.""" self.hidden_state = self.hidden_state.detach() self.cell_state = self.cell_state.detach() - def validation_step(self, batch, batch_idx) -> torch.Tensor: + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Lightning method for validation step. 
+ + Parameters + ---------- + batch : Tuple[torch.Tensor, torch.Tensor] + Batch of input data and labels + batch_idx : int + Index of the current batch + + Returns + ------- + torch.Tensor + Computed loss + """ loss = self._step(batch, batch_idx) self.log('val_loss', loss) return loss - def test_step(self, batch, batch_idx) -> torch.Tensor: + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """ + Lightning method for test step. + + Parameters + ---------- + batch : Tuple[torch.Tensor, torch.Tensor] + Batch of input data and labels + batch_idx : int + Index of the current batch + + Returns + ------- + torch.Tensor + Computed loss + """ loss = self._step(batch, batch_idx) self.log('test_loss', loss) return loss - def configure_optimizers(self): - args = self.args + def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[Dict]]: + """ + Configure optimizers and learning rate schedulers. + + Returns + ------- + Tuple[List[torch.optim.Optimizer], List[Dict]] + Tuple containing a list of optimizers and a list of scheduler configurations + """ optimizer = torch.optim.AdamW( - self.parameters(), weight_decay=args['weight_decay']) + self.parameters(), weight_decay=self.args['weight_decay']) scheduler = torch.optim.lr_scheduler.OneCycleLR( - optimizer, max_lr=args['lr'], - epochs=args['epochs'], + optimizer, max_lr=self.args['lr'], + epochs=self.args['epochs'], total_steps=self.trainer.estimated_stepping_batches ) lr_scheduler = {'scheduler': scheduler, 'interval': 'step'} @@ -131,13 +224,51 @@ def configure_optimizers(self): class NSVDataset(torch.utils.data.Dataset): - def __init__(self, nsv, dv): + """ + Custom Dataset for neural state vector (binned spike train) data. + + Parameters + ---------- + nsv : List[np.ndarray] + List of trial-segmented neural state vector arrays + dv : List[np.ndarray] + List of trial-segmented behavioral state vector arrays + + Attributes + ---------- + nsv : List[np.ndarray] + List of trial-segmented neural state vector arrays as float32 + dv : List[np.ndarray] + List of trial-segmented behavioral state vector arrays as float32 + """ + def __init__(self, nsv: List[np.ndarray], dv: List[np.ndarray]): self.nsv = [i.astype(np.float32) for i in nsv] self.dv = [i.astype(np.float32) for i in dv] - def __len__(self): + def __len__(self) -> int: + """ + Get the length of the dataset. + + Returns + ------- + int + Number of samples in the dataset + """ return len(self.nsv) - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Get a sample from the dataset. + + Parameters + ---------- + idx : int + Index of the sample + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + Tuple containing NSV and DV arrays + """ nsv, dv = self.nsv[idx], self.dv[idx] return nsv, dv diff --git a/neuro_py/ensemble/decoding/mlp.py b/neuro_py/ensemble/decoding/mlp.py index 5c9765d..a88188f 100644 --- a/neuro_py/ensemble/decoding/mlp.py +++ b/neuro_py/ensemble/decoding/mlp.py @@ -1,3 +1,5 @@ +from typing import List, Union, Dict, Optional + import torch import lightning as L @@ -5,12 +7,56 @@ class MLP(L.LightningModule): - """MLPs in Pytorch of an arbitrary number of hidden - layers of potentially different sizes. """ - def __init__(self, in_dim=100, out_dim=2, hidden_dims=(), use_bias=True, args=None): + Multi-Layer Perceptron (MLP) in PyTorch with an arbitrary number of hidden layers. 
+ + This class implements an MLP model using PyTorch Lightning, allowing for flexible + architecture with varying hidden layer sizes and dropout probabilities. + + Parameters + ---------- + in_dim : int, optional + Dimensionality of input data, by default 100 + out_dim : int, optional + Dimensionality of output data, by default 2 + hidden_dims : List[Union[int, float]], optional + List containing architectural parameters of the model. If an element is + an int, it represents a hidden layer of that size. If an element is a float, + it represents a dropout layer with that probability. By default () + use_bias : bool, optional + Whether to use bias in all linear layers, by default True + args : Optional[Dict], optional + Dictionary containing the hyperparameters of the model, by default None + + Attributes + ---------- + main : nn.Sequential + The main sequential container of the MLP layers + """ + def __init__(self, in_dim: int = 100, out_dim: int = 2, + hidden_dims: List[Union[int, float]] = (), + use_bias: bool = True, args: Optional[Dict] = None): + super().__init__() + self.save_hyperparameters() + self.in_dim = in_dim + self.out_dim = out_dim + self.args = args if args is not None else {} + activations = nn.CELU if self.args.get('activations') is None else self.args['activations'] + + layers = self._build_layers(in_dim, out_dim, hidden_dims, use_bias, activations) + self.main = nn.Sequential(*layers) + self._init_params() + + def _build_layers( + self, + in_dim: int, + out_dim: int, + hidden_dims: List[Union[int, float]], + use_bias: bool, + activations: nn.Module + ) -> List[nn.Module]: """ - Constructs a MultiLayerPerceptron + Build the layers of the MLP. Parameters ---------- @@ -18,50 +64,48 @@ def __init__(self, in_dim=100, out_dim=2, hidden_dims=(), use_bias=True, args=No Dimensionality of input data out_dim : int Dimensionality of output data - hidden_dims : List - List containing architectural parameters of the model. If element is - int, it is a hidden layer of that size. If element is float, it is a - dropout layer with that probability. + hidden_dims : List[Union[int, float]] + List of hidden layer sizes and dropout probabilities use_bias : bool - Whether to use bias or not in the all linear layers - args : dict - Dictionary containing the hyperparameters of the model - """ - super().__init__() - self.save_hyperparameters() - self.in_dim = in_dim - self.out_dim = out_dim - self.args = args - activations = nn.CELU if self.args['activations'] is None else self.args['activations'] + Whether to use bias in linear layers + activations : nn.Module + Activation function to use - # If we have no hidden layer, just initialize a linear model (e.g. 
in logistic regression) + Returns + ------- + List[nn.Module] + List of layers for the MLP + """ if len(hidden_dims) == 0: - layers = [nn.Linear(in_dim, out_dim, bias=use_bias)] - else: - # 'Actual' MLP with dimensions in_dim - num_hidden_layers*[hidden_dim] - out_dim - layers = [] - hidden_dims = [in_dim] + hidden_dims - - # Loop until before the last layer - for i, hidden_dim in enumerate(hidden_dims[:-1]): - if isinstance(hidden_dim, float): - continue - if isinstance(hidden_dims[i+1], float): - layers += [nn.Linear(hidden_dim, hidden_dims[i + 2], bias=use_bias), - nn.Dropout(p=hidden_dims[i+1]), - activations() if i < len(hidden_dims)-1 else nn.Tanh()] - else: - layers += [nn.Linear(hidden_dim, hidden_dims[i + 1], bias=use_bias), - activations() if i < len(hidden_dims)-1 else nn.Tanh()] - - # Add final layer to the number of classes - layers += [nn.Linear(hidden_dims[-1], out_dim, bias=use_bias)] - if args['clf']: - layers += [nn.LogSoftmax(dim=1)] - - self.main = nn.Sequential(*layers) - - def init_params(m): + return [nn.Linear(in_dim, out_dim, bias=use_bias)] + + layers = [] + hidden_dims = [in_dim] + hidden_dims + + for i, hidden_dim in enumerate(hidden_dims[:-1]): + if isinstance(hidden_dim, float): + continue + if isinstance(hidden_dims[i+1], float): + layers.extend([ + nn.Linear(hidden_dim, hidden_dims[i + 2], bias=use_bias), + nn.Dropout(p=hidden_dims[i+1]), + activations() if i < len(hidden_dims)-1 else nn.Tanh() + ]) + else: + layers.extend([ + nn.Linear(hidden_dim, hidden_dims[i + 1], bias=use_bias), + activations() if i < len(hidden_dims)-1 else nn.Tanh() + ]) + + layers.append(nn.Linear(hidden_dims[-1], out_dim, bias=use_bias)) + if self.args.get('clf', False): + layers.append(nn.LogSoftmax(dim=1)) + + return layers + + def _init_params(self) -> None: + """Initialize the parameters of the model.""" + def init_params(m: nn.Module) -> None: if isinstance(m, nn.Linear): torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu') if m.bias is not None: @@ -70,15 +114,15 @@ def init_params(m): nn.init.uniform_(m.bias, -bound, bound) # LeCunn init self.main.apply(init_params) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Defines the network structure and flow from input to output + Defines the network structure and flow from input to output. Parameters ---------- x : torch.Tensor Input data - + Returns ------- torch.Tensor @@ -86,43 +130,112 @@ def forward(self, x): """ return self.main(x) - def _step(self, batch, batch_idx): - xs, ys = batch # unpack the batch - outs = self(xs) # apply the model - loss = self.args['criterion'](outs, ys) # compute the (squared error) loss + def _step(self, batch: tuple, batch_idx: int) -> torch.Tensor: + """ + Perform a single step (forward pass + loss calculation). + + Parameters + ---------- + batch : tuple + Batch of input data and labels + batch_idx : int + Index of the current batch + + Returns + ------- + torch.Tensor + Computed loss + """ + xs, ys = batch + outs = self(xs) + loss = self.args['criterion'](outs, ys) return loss - def training_step(self, batch, batch_idx): + def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor: + """ + Lightning method for training step. 
+ + Parameters + ---------- + batch : tuple + Batch of input data and labels + batch_idx : int + Index of the current batch + + Returns + ------- + torch.Tensor + Computed loss + """ loss = self._step(batch, batch_idx) self.log('train_loss', loss) return loss - def validation_step(self, batch, batch_idx): + def validation_step(self, batch: tuple, batch_idx: int) -> torch.Tensor: + """ + Lightning method for validation step. + + Parameters + ---------- + batch : tuple + Batch of input data and labels + batch_idx : int + Index of the current batch + + Returns + ------- + torch.Tensor + Computed loss + """ loss = self._step(batch, batch_idx) self.log('val_loss', loss) return loss - def test_step(self, batch, batch_idx): + def test_step(self, batch: tuple, batch_idx: int) -> torch.Tensor: + """ + Lightning method for test step. + + Parameters + ---------- + batch : tuple + Batch of input data and labels + batch_idx : int + Index of the current batch + + Returns + ------- + torch.Tensor + Computed loss + """ loss = self._step(batch, batch_idx) self.log('test_loss', loss) return loss def configure_optimizers(self): - args = self.args + """ + Configure optimizers and learning rate schedulers. + + Returns + ------- + tuple + Tuple containing a list of optimizers and a list of scheduler configurations + """ optimizer = torch.optim.AdamW( - self.parameters(), weight_decay=args['weight_decay'], betas=(0.9, 0.999), - amsgrad=True) - # scheduler = torch.optim.lr_scheduler.OneCycleLR( - # optimizer, max_lr=args['lr'], - # epochs=args['epochs'], steps_per_epoch=len(self.trainer._data_connector._train_dataloader_source.dataloader()) - # ) - # write a cycliclr scheduler + self.parameters(), + weight_decay=self.args['weight_decay'], + betas=(0.9, 0.999), + amsgrad=True + ) scheduler = torch.optim.lr_scheduler.CyclicLR( - optimizer, base_lr=args['base_lr'], max_lr=args['lr'], - step_size_up=self.args['scheduler_step_size_multiplier']*self.args['num_training_batches'], # assuming 1 batch_size, multiply if more: https://discuss.pytorch.org/t/step-size-for-cyclic-scheduler/69262/4 + optimizer, + base_lr=self.args['base_lr'], + max_lr=self.args['lr'], + step_size_up=self.args['scheduler_step_size_multiplier'] * self.args['num_training_batches'], cycle_momentum=False, - mode='triangular2', gamma=0.99994, - last_epoch=-1, verbose=False + mode='triangular2', + gamma=0.99994, + last_epoch=-1, + verbose=False ) lr_scheduler = {'scheduler': scheduler, 'interval': 'step'} return [optimizer], [lr_scheduler] diff --git a/neuro_py/ensemble/decoding/pipeline.py b/neuro_py/ensemble/decoding/pipeline.py index 658ac1b..b56073b 100644 --- a/neuro_py/ensemble/decoding/pipeline.py +++ b/neuro_py/ensemble/decoding/pipeline.py @@ -2,6 +2,10 @@ import os import random +from typing import List, Tuple, Dict, Optional, Any + +import sklearn.preprocessing + import numpy as np import pandas as pd import bottleneck as bn @@ -10,35 +14,38 @@ import torch import zlib +from numpy.typing import NDArray + from .mlp import MLP # noqa from .lstm import LSTM # noqa from .m2mlstm import M2MLSTM, NSVDataset # noqa from .transformer import NDT # noqa -def seed_worker(worker_id): +def seed_worker(worker_id: int) -> None: """ - DataLoader will reseed workers following randomness in - multi-process data loading algorithm. - - Args: - worker_id: integer - ID of subprocess to seed. 
0 means that - the data will be loaded in the main process - Refer: https://pytorch.org/docs/stable/data.html#data-loading-randomness for more details + Seed a worker with the given ID for reproducibility in data loading. + + Parameters + ---------- + worker_id : int + The ID of the worker to be seeded. + + Notes + ----- + This function is used to ensure reproducibility when using multi-process data loading. """ worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) random.seed(worker_seed) - -def get_spikes_with_history(neural_data, bins_before, bins_after, bins_current=1): +def get_spikes_with_history(neural_data: np.ndarray, bins_before: int, bins_after: int, bins_current: int = 1) -> np.ndarray: """ Create the covariate matrix of neural activity. Parameters ---------- - neural_data : numpy.ndarray + neural_data : np.ndarray A matrix of size "number of time bins" x "number of neurons", representing the number of spikes in each time bin for each neuron. bins_before : int @@ -46,25 +53,19 @@ def get_spikes_with_history(neural_data, bins_before, bins_after, bins_current=1 bins_after : int How many bins of neural data after the output are used for decoding. bins_current : int, optional - Whether to use the concurrent time bin of neural data for decoding. - Default is 1. + Whether to use the concurrent time bin of neural data for decoding, by + default 1. Returns ------- - numpy.ndarray + np.ndarray A matrix of size "number of total time bins" x "number of surrounding time bins used for prediction" x "number of neurons". - For every time bin, there are the firing rates of all neurons from the - specified number of time bins before (and after). """ - num_examples = neural_data.shape[0] # Number of total time bins we have neural data for - num_neurons = neural_data.shape[1] # Number of neurons - surrounding_bins = bins_before + bins_after + bins_current # Number of surrounding time bins used for prediction - X = np.zeros([num_examples, surrounding_bins, num_neurons]) # Initialize covariate matrix with zeros - - # Loop through each time bin, and collect the spikes occurring in surrounding time bins - # Note: The first "bins_before" and last "bins_after" rows of X will remain filled with zeros, - # since they don't get filled below due to insufficient preceding or succeeding bins. + num_examples, num_neurons = neural_data.shape + surrounding_bins = bins_before + bins_after + bins_current + X = np.zeros([num_examples, surrounding_bins, num_neurons]) + for i in range(num_examples - bins_before - bins_after): start_idx = i end_idx = start_idx + surrounding_bins @@ -72,7 +73,31 @@ def get_spikes_with_history(neural_data, bins_before, bins_after, bins_current=1 return X -def _get_trial_spikes_with_no_overlap_history(X, bins_before, bins_after, bins_current): +def _get_trial_spikes_with_no_overlap_history( + X: NDArray, + bins_before: int, + bins_after: int, + bins_current: int + ) -> List[NDArray]: + """ + Get trial spikes with no overlap history. + + Parameters + ---------- + X : NDArray + Input binned spike data. + bins_before : int + Number of bins before the current bin. + bins_after : int + Number of bins after the current bin. + bins_current : int + Number of current bins. + + Returns + ------- + List[NDArray] + List of trial covariates with no overlap history. 
+ """ nonoverlap_trial_covariates = [] if X.ndim == 2: X_cov = get_spikes_with_history( @@ -87,9 +112,42 @@ def _get_trial_spikes_with_no_overlap_history(X, bins_before, bins_after, bins_c return nonoverlap_trial_covariates def format_trial_segs_nsv( - nsv_train_normed, nsv_rest_normed, bv_train, bv_rest, predict_bv, - bins_before=0, bins_current=1, bins_after=0, - ):# -> tuple[NDArray, list, ndarray[Any, dtype], list, Any | nda...: + nsv_train_normed: List[NDArray], + nsv_rest_normed: List[NDArray], + bv_train: NDArray, + bv_rest: List[NDArray], + predict_bv: List[int], + bins_before: int = 0, + bins_current: int = 1, + bins_after: int = 0 +) -> Tuple[NDArray, List[NDArray], NDArray, List[NDArray], NDArray, List[NDArray]]: + """ + Format trial segments for neural state vectors. + + Parameters + ---------- + nsv_train_normed : List[NDArray] + Normalized neural state vectors for training. + nsv_rest_normed : List[NDArray] + Normalized neural state vectors for rest. + bv_train : NDArray + Behavioral state vectors for training. + bv_rest : List[NDArray] + Behavioral state vectors for rest. + predict_bv : List[int] + Indices of behavioral state vectors to predict. + bins_before : int, optional + Number of bins before the current bin, by default 0. + bins_current : int, optional + Number of current bins, by default 1. + bins_after : int, optional + Number of bins after the current bin, by default 0. + + Returns + ------- + Tuple[NDArray, List[NDArray], NDArray, List[NDArray], NDArray, List[NDArray]] + Formatted trial segments for neural state vectors. + """ is_2D = nsv_train_normed[0].ndim == 1 # Format for RNNs: covariate matrix including spike history from previous bins X_train = np.concatenate(_get_trial_spikes_with_no_overlap_history( @@ -121,7 +179,28 @@ def format_trial_segs_nsv( return X_train, X_rest, X_flat_train, X_flat_rest, y_train, y_rest -def zscore_trial_segs(train, rest_feats=None, normparams=None): +def zscore_trial_segs( + train: NDArray, + rest_feats: Optional[List[NDArray]] = None, + normparams: Optional[Dict[str, Any]] = None +) -> Tuple[NDArray, List[NDArray], Dict[str, Any]]: + """ + Z-score trial segments. + + Parameters + ---------- + train : NDArray + Training data. + rest_feats : Optional[List[NDArray]], optional + Rest features, by default None. + normparams : Optional[Dict[str, Any]], optional + Normalization parameters, by default None. + + Returns + ------- + Tuple[NDArray, List[NDArray], Dict[str, Any]] + Normalized train data, normalized rest features, and normalization parameters. + """ is_2D = train[0].ndim == 1 concat_train = train if is_2D else np.concatenate(train) train_mean = normparams['X_train_mean'] if normparams is not None else bn.nanmean(concat_train, axis=0) @@ -159,7 +238,46 @@ def zscore_trial_segs(train, rest_feats=None, normparams=None): X_train_notnan_mask=train_notnan_cols, ) -def normalize_format_trial_segs(nsv_train, nsv_rest, bv_train, bv_rest, predict_bv=[4,5], bins_before=0, bins_current=1, bins_after=0, normparams=None): +def normalize_format_trial_segs( + nsv_train: NDArray, + nsv_rest: List[NDArray], + bv_train: NDArray, + bv_rest: List[NDArray], + predict_bv: List[int] = [4, 5], + bins_before: int = 0, + bins_current: int = 1, + bins_after: int = 0, + normparams: Optional[Dict[str, Any]] = None +) -> Tuple[NDArray, NDArray, NDArray, List[Tuple[NDArray, NDArray, NDArray]], Dict[str, Any]]: + """ + Normalize and format trial segments. + + Parameters + ---------- + nsv_train : NDArray + Neural state vectors for training. 
+ nsv_rest : List[NDArray] + Neural state vectors for rest. + bv_train : NDArray + Behavioral state vectors for training. + bv_rest : List[NDArray] + Behavioral state vectors for rest. + predict_bv : List[int], optional + Indices of behavioral state vectors to predict, by default [4, 5]. + bins_before : int, optional + Number of bins before the current bin, by default 0. + bins_current : int, optional + Number of current bins, by default 1. + bins_after : int, optional + Number of bins after the current bin, by default 0. + normparams : Optional[Dict[str, Any]], optional + Normalization parameters, by default None. + + Returns + ------- + Tuple[NDArray, NDArray, NDArray, List[Tuple[NDArray, NDArray, NDArray]], Dict[str, Any]] + Normalized and formatted trial segments. + """ nsv_train_normed, nsv_rest_normed, norm_params = zscore_trial_segs(nsv_train, nsv_rest, normparams) (X_train, X_rest, X_flat_train, X_flat_rest, y_train, y_rest @@ -179,9 +297,48 @@ def normalize_format_trial_segs(nsv_train, nsv_rest, bv_train, bv_rest, predict_ return X_train, X_flat_train, y_train, tuple(zip(X_rest, X_flat_rest, y_centered_rest)), norm_params def minibatchify( - Xtrain, ytrain, Xval, yval, Xtest, ytest, seed=0, - batch_size=128, num_workers=5, modeltype='MLP' - ): + Xtrain: NDArray, + ytrain: NDArray, + Xval: NDArray, + yval: NDArray, + Xtest: NDArray, + ytest: NDArray, + seed: int = 0, + batch_size: int = 128, + num_workers: int = 5, + modeltype: str = 'MLP' + ) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]: + """ + Create minibatches for training, validation, and testing. + + Parameters + ---------- + Xtrain : NDArray + Training features. + ytrain : NDArray + Training labels. + Xval : NDArray + Validation features. + yval : NDArray + Validation labels. + Xtest : NDArray + Test features. + ytest : NDArray + Test labels. + seed : int, optional + Random seed, by default 0. + batch_size : int, optional + Batch size, by default 128. + num_workers : int, optional + Number of workers for data loading, by default 5. + modeltype : str, optional + Type of model, by default 'MLP'. + + Returns + ------- + Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader] + DataLoaders for training, validation, and testing. + """ g_seed = torch.Generator() g_seed.manual_seed(seed) train = torch.utils.data.TensorDataset( @@ -211,7 +368,28 @@ def minibatchify( return train_loader, val_loader, test_loader -def normalize_labels(y_train, y_val, y_test): +def normalize_labels( + y_train: NDArray, + y_val: NDArray, + y_test: NDArray + ) -> Tuple[Tuple[NDArray, NDArray, NDArray], int]: + """ + Normalize labels to integers in [0, n_classes). + + Parameters + ---------- + y_train : NDArray + Training labels. + y_val : NDArray + Validation labels. + y_test : NDArray + Test labels. + + Returns + ------- + Tuple[Tuple[NDArray, NDArray, NDArray], int] + Normalized labels and number of classes. + """ # map labels to integers in [0, n_classes) uniq_labels = np.unique(np.concatenate((y_train, y_val, y_test))) n_classes = len(uniq_labels) @@ -221,16 +399,74 @@ def normalize_labels(y_train, y_val, y_test): y_test = np.vectorize(lambda v: uniq_labels_idx_map[v])(y_test) return (y_train, y_val, y_test), n_classes -def create_model(hyperparams): +def create_model(hyperparams: Dict[str, Any]) -> Tuple[Any, pl.LightningModule]: + """ + Create a model based on the given hyperparameters. 
+ + Parameters + ---------- + hyperparams : Dict[str, Any] + Dictionary containing model hyperparameters. + + Returns + ------- + Tuple[Any, pl.LightningModule] + The decoder class and instantiated model. + """ decoder = eval(f"{hyperparams['model']}") model = decoder(**hyperparams['model_args']) + if 'LSTM' in hyperparams['model']: model.init_hidden(hyperparams['batch_size']) model.hidden_state = model.hidden_state.to(hyperparams['device']) model.cell_state = model.cell_state.to(hyperparams['device']) + return decoder, model -def preprocess_data(hyperparams, ohe, nsv_train, nsv_val, nsv_test, bv_train, bv_val, bv_test, foldnormparams=None): +def preprocess_data( + hyperparams: Dict[str, Any], + ohe: sklearn.preprocessing.OneHotEncoder, + nsv_train: NDArray, + nsv_val: NDArray, + nsv_test: NDArray, + bv_train: NDArray, + bv_val: NDArray, + bv_test: NDArray, + foldnormparams: Optional[Dict[str, Any]] = None + ) -> Tuple[ + Tuple[NDArray, NDArray, NDArray, NDArray, NDArray, NDArray], + Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader], + Dict[str, Any] + ]: + """ + Preprocess the data for model training and evaluation. + + Parameters + ---------- + hyperparams : Dict[str, Any] + Dictionary containing hyperparameters. + ohe : OneHotEncoder + One-hot encoder for categorical variables. + nsv_train : NDArray + Neural state vectors for training. + nsv_val : NDArray + Neural state vectors for validation. + nsv_test : NDArray + Neural state vectors for testing. + bv_train : NDArray + Behavioral state vectors for training. + bv_val : NDArray + Behavioral state vectors for validation. + bv_test : NDArray + Behavioral state vectors for testing. + foldnormparams : Optional[Dict[str, Any]], optional + Normalization parameters for the current fold, by default None. + + Returns + ------- + Tuple[Tuple[NDArray, NDArray, NDArray, NDArray, NDArray, NDArray], Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader], Dict[str, Any]] + Preprocessed data, data loaders, and normalization parameters. + """ bins_before = hyperparams['bins_before'] bins_current = hyperparams['bins_current'] bins_after = hyperparams['bins_after'] @@ -344,7 +580,34 @@ def preprocess_data(hyperparams, ohe, nsv_train, nsv_val, nsv_test, bv_train, bv return (X_train, y_train, X_val, y_val, X_test, y_test), (train_loader, val_loader, test_loader), fold_norm_params -def evaluate_model(hyperparams, ohe, predictor, X_test, y_test): +def evaluate_model( + hyperparams: Dict[str, Any], + ohe: sklearn.preprocessing.OneHotEncoder, + predictor: torch.nn.Module, + X_test: NDArray, + y_test: NDArray + ) -> Tuple[Dict[str, float], NDArray]: + """ + Evaluate the model on test data. + + Parameters + ---------- + hyperparams : Dict[str, Any] + Dictionary containing hyperparameters. + ohe : OneHotEncoder + One-hot encoder for categorical variables. + predictor : torch.nn.Module + The trained model. + X_test : NDArray + Test features. + y_test : NDArray + Test labels. + + Returns + ------- + Tuple[Dict[str, float], NDArray] + Evaluation metrics and model predictions. 
+ """ if hyperparams['model'] in ('M2MLSTM', 'NDT'): out_dim = hyperparams['model_args']['out_dim'] with torch.no_grad(): @@ -384,7 +647,20 @@ def evaluate_model(hyperparams, ohe, predictor, X_test, y_test): metrics = dict(coeff_determination=coeff_determination, rmse=rmse) return metrics, bv_preds_fold -def shuffle_nsv_intrialsegs(nsv_trialsegs): +def shuffle_nsv_intrialsegs(nsv_trialsegs: List[pd.DataFrame]) -> NDArray: + """ + Shuffle neural state variables within trial segments. + + Parameters + ---------- + nsv_trialsegs : List[pd.DataFrame] + List of neural state variable trial segments. + + Returns + ------- + NDArray + Shuffled neural state variables. + """ nsv_shuffled_intrialsegs = [] for nsv_tseg in nsv_trialsegs: # shuffle the data @@ -393,15 +669,18 @@ def shuffle_nsv_intrialsegs(nsv_trialsegs): ) return np.asarray(nsv_shuffled_intrialsegs, dtype=object) -def train_model(partitions, hyperparams, resultspath=None, stop_partition=None): - """Generic function to train a DNN model on the given data partitions. - - In-built caching & checkpointing is used to save the best model based on the - validation loss. +def train_model( + partitions: List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]], + hyperparams: Dict[str, Any], + resultspath: Optional[str] = None, + stop_partition: Optional[int] = None +) -> Tuple[List[np.ndarray], List[Any], List[Dict[str, Any]], Dict[str, List[float]]]: + """ + Train a DNN model on the given data partitions with in-built caching & checkpointing. Parameters ---------- - partitions : array-like + partitions : List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]] K-fold partitions of the data with the following format: [(nsv_train, bv_train, nsv_val, bv_val, nsv_test, bv_test), ...] Each element of the list is a tuple of numpy arrays containing the with @@ -410,83 +689,11 @@ def train_model(partitions, hyperparams, resultspath=None, stop_partition=None): (ntrials, nbins, nfeats) where nfeats is the number of neurons for the neural state vectors and number of behavioral features to be predicted for the behavioral variables. - hyperparams : dict - Dictionary containing the hyperparameters for the model training. The - dictionary should contain the following keys: - - `model`: str, the type of the model to be trained. Multi-layer - Perceptron (MLP), Long Short-Term Memory (LSTM), many-to-many LSTM - (M2MLSTM), Transformer (NDT). - - `model_args`: dict, the arguments to be passed to the model - constructor. The arguments should be in the format expected by the - model constructor. - - `in_dim`: The number of input features. - - `out_dim`: The number of output features. - - `hidden_dims`: The number of hidden units each hidden layer of the - model. Can also take float values to specify the dropout rate. - - For LSTM and M2MLSTM, it should be a tuple of the hidden size, - the number of layers, and the dropout rate. - If the model is an MLP, it should be a list of hidden layer - sizes which can also take float values to specify the dropout - rate. - - If the model is an LSTM or M2MLSTM, it should be a list of the - hidden layer size, the number of layers, and the dropout rate. - - If the model is an NDT, it should be a list of the hidden - layer size, the number of layers, the number of attention heads, - the dropout rate for the encoder layer, and the dropout rate - applied before the decoder layer. - - `max_context_len`: The maximum context length for the transformer - model. 
Only used if the model is an NDT. - - `args`: - - `clf`: If True, the model is a classifier; otherwise, it is a - regressor. - - `activations`: The activation functions for each layer. - - `criterion`: The loss function to optimize. - - `epochs`: The number of complete passes through the training - dataset. - - `lr`: Controls how much to change the model in response to the - estimated error each time the model weights are updated. A - smaller value ensures stable convergence but may slow down - training, while a larger value speeds up training but risks - overshooting. - - `base_lr`: The initial learning rate for the learning rate - scheduler. - - `max_grad_norm`: The maximum norm of the gradients. - - `iters_to_accumulate`: The number of iterations to accumulate - gradients. - - `weight_decay`: The L2 regularization strength. - - `num_training_batches`: The number of training batches. If - None, the number of batches is calculated based on the batch - size and the length of the training data. - - `scheduler_step_size_multiplier`: The multiplier for the - learning rate scheduler step size. Higher values lead to - faster learning rate decay. - - `bins_before`: int, the number of bins before the current bin to - include in the input data. - - `bins_current`: int, the number of bins in the current time bin to - include in the input data. - - `bins_after`: int, the number of bins after the current bin to include - in the input data. - - `behaviors`: list, the indices of the columns of behavioral features - to be predicted. Selected behavioral variable must have homogenous - data types across all features (continuous for regression and - categorical for classification) - - `batch_size`: int, the number of training examples utilized in one - iteration. Larger batch sizes offer stable gradient estimates but - require more memory, while smaller batches introduce noise that can - help escape local minima. - - When using M2MLSTM or NDT and input trials are of inconsistents - lengths, the batch size should be set to 1. - - M2MLSTM does not support batch_size != 1. - - `num_workers`: int, The number of parallel processes to use for data - loading. Increasing the number of workers can speed up data loading - but may lead to memory issues. Too many workers can also slow down - the training process due to contention for resources. - - `device`: str, the device to use for training. Should be 'cuda' or - 'cpu'. - - `seed`: int, the random seed for reproducibility. - resultspath : str or None, optional + hyperparams : Dict[str, Any] + Dictionary containing the hyperparameters for the model training. + resultspath : Optional[str], default=None Path to the directory where the trained models and logs will be saved. - stop_partition : int, optional + stop_partition : Optional[int], default=None Index of the partition to stop training at. Only useful for debugging, by default None @@ -496,6 +703,80 @@ def train_model(partitions, hyperparams, resultspath=None, stop_partition=None): Tuple containing the predicted behavioral variables for each fold, the trained models for each fold, the normalization parameters for each fold, and the evaluation metrics for each fold. + + Notes + ----- + The hyperparameters dictionary should contain the following keys: + - `model`: str, the type of the model to be trained. Multi-layer + Perceptron (MLP), Long Short-Term Memory (LSTM), many-to-many LSTM + (M2MLSTM), Transformer (NDT). + - `model_args`: dict, the arguments to be passed to the model constructor. 
+      The arguments should be in the format expected by the model
+      constructor.
+        - `in_dim`: The number of input features.
+        - `out_dim`: The number of output features.
+        - `hidden_dims`: The number of hidden units in each hidden layer of
+          the model. Can also take float values to specify dropout rates.
+            - If the model is an MLP, it should be a list of hidden layer
+              sizes, which can also contain float values to specify dropout
+              rates.
+            - If the model is an LSTM or M2MLSTM, it should be a tuple of
+              the hidden size, the number of layers, and the dropout rate.
+            - If the model is an NDT, it should be a list of the hidden
+              layer size, the number of layers, the number of attention
+              heads, the dropout rate for the encoder layer, and the dropout
+              rate applied before the decoder layer.
+        - `max_context_len`: The maximum context length for the transformer
+          model. Only used if the model is an NDT.
+        - `args`:
+            - `clf`: If True, the model is a classifier; otherwise, it is a
+              regressor.
+            - `activations`: The activation functions for each layer.
+            - `criterion`: The loss function to optimize.
+            - `epochs`: The number of complete passes through the training
+              dataset.
+            - `lr`: Controls how much to change the model in response to the
+              estimated error each time the model weights are updated. A
+              smaller value ensures stable convergence but may slow down
+              training, while a larger value speeds up training but risks
+              overshooting.
+            - `base_lr`: The initial learning rate for the learning rate
+              scheduler.
+            - `max_grad_norm`: The maximum norm of the gradients.
+            - `iters_to_accumulate`: The number of iterations to accumulate
+              gradients.
+            - `weight_decay`: The L2 regularization strength.
+            - `num_training_batches`: The number of training batches. If
+              None, the number of batches is calculated based on the batch
+              size and the length of the training data.
+            - `scheduler_step_size_multiplier`: The multiplier for the
+              learning rate scheduler step size. Higher values lead to
+              faster learning rate decay.
+    - `bins_before`: int, the number of bins before the current bin to
+      include in the input data.
+    - `bins_current`: int, the number of bins in the current time bin to
+      include in the input data.
+    - `bins_after`: int, the number of bins after the current bin to include
+      in the input data.
+    - `behaviors`: list, the indices of the columns of behavioral features
+      to be predicted. The selected behavioral variables must have
+      homogeneous data types across all features (continuous for regression,
+      categorical for classification).
+    - `batch_size`: int, the number of training examples utilized in one
+      iteration. Larger batch sizes offer stable gradient estimates but
+      require more memory, while smaller batches introduce noise that can
+      help escape local minima.
+        - When using M2MLSTM or NDT with input trials of inconsistent
+          lengths, the batch size should be set to 1.
+        - M2MLSTM does not support batch_size != 1.
+    - `num_workers`: int, the number of parallel processes to use for data
+      loading. Increasing the number of workers can speed up data loading
+      but may lead to memory issues. Too many workers can also slow down
+      the training process due to contention for resources.
+    - `device`: str, the device to use for training. Should be 'cuda' or
+      'cpu'.
+    - `seed`: int, the random seed for reproducibility.
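+
+    Examples
+    --------
+    A minimal, illustrative configuration for an MLP regressor. The values
+    below are placeholders meant only to show the expected structure
+    (`model_args` mirrors the decoder constructor, with training options
+    nested under `args`); they are not tuned recommendations.
+
+    >>> import torch.nn as nn
+    >>> hyperparams = dict(
+    ...     model='MLP',
+    ...     model_args=dict(
+    ...         in_dim=100,
+    ...         out_dim=2,
+    ...         hidden_dims=[64, 0.2, 64],
+    ...         args=dict(
+    ...             clf=False,
+    ...             activations=nn.CELU,
+    ...             criterion=nn.MSELoss(),
+    ...             epochs=10,
+    ...             lr=3e-3,
+    ...             base_lr=1e-4,
+    ...             max_grad_norm=1.0,
+    ...             iters_to_accumulate=1,
+    ...             weight_decay=1e-2,
+    ...             num_training_batches=None,
+    ...             scheduler_step_size_multiplier=2,
+    ...         ),
+    ...     ),
+    ...     bins_before=0,
+    ...     bins_current=1,
+    ...     bins_after=0,
+    ...     behaviors=[4, 5],
+    ...     batch_size=128,
+    ...     num_workers=5,
+    ...     device='cpu',
+    ...     seed=0,
+    ... )
+    >>> # bv_preds, models, norm_params, metrics = train_model(partitions, hyperparams)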
""" ohe = sklearn.preprocessing.OneHotEncoder() bv_preds_folds = [] @@ -585,7 +866,34 @@ def train_model(partitions, hyperparams, resultspath=None, stop_partition=None): break return bv_preds_folds, bv_models_folds, norm_params_folds, metrics_folds -def predict_models_folds(partitions, hyperparams, bv_models_folds, foldnormparams): +def predict_models_folds( + partitions: List[Tuple[NDArray, NDArray, NDArray, NDArray, NDArray, NDArray]], + hyperparams: Dict[str, Any], + bv_models_folds: List[Any], + foldnormparams: List[Dict[str, Any]] + ) -> Tuple[List[NDArray], Dict[str, List[float]]]: + """ + Predict and evaluate models across multiple folds. + + Parameters + ---------- + partitions : List[Tuple[NDArray, NDArray, NDArray, NDArray, NDArray, NDArray]] + List of data partitions for each fold. Each partition contains: + (nsv_train, bv_train, nsv_val, bv_val, nsv_test, bv_test) + hyperparams : Dict[str, Any] + Dictionary of hyperparameters for the models. + bv_models_folds : List[Any] + List of trained models for each fold. + foldnormparams : List[Dict[str, Any]] + List of normalization parameters for each fold. + + Returns + ------- + Tuple[List[NDArray], Dict[str, List[float]]] + A tuple containing: + - List of predictions for each fold + - Dictionary of evaluation metrics for each fold + """ ohe = sklearn.preprocessing.OneHotEncoder() bv_preds_folds = [] metrics_folds = dict() diff --git a/neuro_py/ensemble/decoding/preprocess.py b/neuro_py/ensemble/decoding/preprocess.py index 8aa3428..7bc9e48 100644 --- a/neuro_py/ensemble/decoding/preprocess.py +++ b/neuro_py/ensemble/decoding/preprocess.py @@ -1,16 +1,50 @@ +from typing import List, Tuple, Union + import numpy as np -import sklearn.model_selection +import pandas as pd + +from sklearn.model_selection import StratifiedKFold + + +def split_data(trial_nsvs: np.ndarray, splitby: np.ndarray, trainsize: float = 0.8, seed: int = 0) -> List[np.ndarray]: + """ + Split data into stratified folds. + Parameters + ---------- + trial_nsvs : np.ndarray + Neural state vectors for trials. + splitby : np.ndarray + Labels for stratification. + trainsize : float, optional + Proportion of data to use for training, by default 0.8 + seed : int, optional + Random seed for reproducibility, by default 0 -def split_data(trial_nsvs, splitby, trainsize=.8, seed=0): + Returns + ------- + List[np.ndarray] + List of indices for each fold. + """ n_splits = int(np.round(1 / ((1 - trainsize) / 2))) - skf = sklearn.model_selection.StratifiedKFold( - n_splits=n_splits, shuffle=True, random_state=seed) - folds = [ - fold_indices for _, fold_indices in skf.split(trial_nsvs, splitby)] + skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed) + folds = [fold_indices for _, fold_indices in skf.split(trial_nsvs, splitby)] return folds -def partition_indices(folds): +def partition_indices(folds: List[np.ndarray]) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]: + """ + Partition indices into train, validation, and test sets. + + Parameters + ---------- + folds : List[np.ndarray] + Indices for each fold. + + Returns + ------- + List[Tuple[np.ndarray, np.ndarray, np.ndarray]] + Train, validation, and test indices. 
+ """ partition_mask = np.zeros(len(folds), dtype=int) partition_mask[0:2] = (2, 1) folds_arr = np.asarray(folds, dtype=object) @@ -25,45 +59,53 @@ def partition_indices(folds): partitions_indices.append((train_indices, val_indices, test_indices)) return partitions_indices -def partition_sets(partitions_indices, nsv_trial_segs, bv_trial_segs): - """Partition neural state vectors and behavioral variables into train, +def partition_sets( + partitions_indices: List[Tuple[np.ndarray, np.ndarray, np.ndarray]], + nsv_trial_segs: Union[np.ndarray, pd.DataFrame], + bv_trial_segs: Union[np.ndarray, pd.DataFrame] +) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]]: + """ + Partition neural state vectors and behavioral variables into train, validation, and test sets. Parameters ---------- - partitions_indices : list[tuple[np.ndarray]] + partitions_indices : List[Tuple[np.ndarray, np.ndarray, np.ndarray]] List of tuples containing indices of divided trials into train, validation, and test sets. - nsv_trial_segs : np.ndarray[pd.DataFrame] or pd.DataFrame + nsv_trial_segs : Union[np.ndarray, pd.DataFrame] Neural state vectors for each trial. Shape: [n_trials, n_timepoints, n_neurons] or [n_timepoints, n_neurons] - bv_trial_segs : np.ndarray[pd.DataFrame] or pd.DataFrame + bv_trial_segs : Union[np.ndarray, pd.DataFrame] Behavioral variables for each trial. Shape: [n_trials, n_timepoints, n_bvars] or [n_timepoints, n_bvars] Returns ------- - partitions : list[tuple[np.ndarray]] + List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]] List of tuples containing train, validation, and test sets for neural state vectors and behavioral variables. """ partitions = [] - is_2D = nsv_trial_segs[0].ndim == 1 + is_2D = nsv_trial_segs.ndim == 1 for (train_indices, val_indices, test_indices) in partitions_indices: - train = nsv_trial_segs.loc[train_indices] if is_2D else \ - np.take(nsv_trial_segs, train_indices) - val = nsv_trial_segs.loc[val_indices] if is_2D else \ - np.take(nsv_trial_segs, val_indices) - test = nsv_trial_segs.loc[test_indices] if is_2D else \ - np.take(nsv_trial_segs, test_indices) + if is_2D: + if isinstance(nsv_trial_segs, pd.DataFrame): + nsv_trial_segs = nsv_trial_segs.loc + bv_trial_segs = bv_trial_segs.loc + train = nsv_trial_segs[train_indices] + val = nsv_trial_segs[val_indices] + test = nsv_trial_segs[test_indices] + train_bv = bv_trial_segs[train_indices] + val_bv = bv_trial_segs[val_indices] + test_bv = bv_trial_segs[test_indices] + else: + train = np.take(nsv_trial_segs, train_indices, axis=0) + val = np.take(nsv_trial_segs, val_indices, axis=0) + test = np.take(nsv_trial_segs, test_indices, axis=0) + train_bv = np.take(bv_trial_segs, train_indices, axis=0) + val_bv = np.take(bv_trial_segs, val_indices, axis=0) + test_bv = np.take(bv_trial_segs, test_indices, axis=0) - train_bv, val_bv, test_bv = ( - bv_trial_segs.loc[train_indices] if is_2D else \ - np.take(bv_trial_segs, train_indices), - bv_trial_segs.loc[val_indices] if is_2D else \ - np.take(bv_trial_segs, val_indices), - bv_trial_segs.loc[test_indices] if is_2D else \ - np.take(bv_trial_segs, test_indices) - ) partitions.append((train, train_bv, val, val_bv, test, test_bv)) return partitions diff --git a/neuro_py/ensemble/decoding/transformer.py b/neuro_py/ensemble/decoding/transformer.py index b494e8c..845bf8e 100644 --- a/neuro_py/ensemble/decoding/transformer.py +++ b/neuro_py/ensemble/decoding/transformer.py @@ -1,60 +1,105 @@ +from typing import Tuple, Dict, 
diff --git a/neuro_py/ensemble/decoding/transformer.py b/neuro_py/ensemble/decoding/transformer.py
index b494e8c..845bf8e 100644
--- a/neuro_py/ensemble/decoding/transformer.py
+++ b/neuro_py/ensemble/decoding/transformer.py
@@ -1,60 +1,105 @@
+from typing import Tuple, Dict, Optional
+
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import lightning as L
-
 class PositionalEncoding(nn.Module):
-    def __init__(self, in_dim, max_context_len, args):
+    """
+    Positional Encoding module for Transformer models.
+
+    Parameters
+    ----------
+    in_dim : int
+        Input dimension of the model
+    max_context_len : int
+        Maximum context length
+    args : Dict
+        Additional arguments (not used in this implementation)
+
+    Attributes
+    ----------
+    pe : torch.Tensor
+        Positional encoding tensor
+    """
+    def __init__(self, in_dim: int, max_context_len: int, args: Dict):
         super().__init__()
-        pe = torch.zeros(max_context_len, in_dim)  # * Can optim to empty
+        pe = torch.zeros(max_context_len, in_dim)
         position = torch.arange(0, max_context_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, in_dim, 2).float() * (-np.log(1e4) / in_dim))
+        div_term = torch.exp(
+            torch.arange(0, in_dim, 2).float() * (-np.log(1e4) / in_dim))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0).transpose(0, 1)  # t x 1 x d
         self.register_buffer('pe', pe)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Add positional encoding to the input tensor.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (seq_len, batch_size, in_dim)
+
+        Returns
+        -------
+        torch.Tensor
+            Input tensor with added positional encoding
+        """
         self.pe = self.pe.to(x.device)
         x = x + self.pe[:x.size(0), :]  # t x 1 x d, # t x b x d
         return x
 
 
 class NDT(L.LightningModule):
-    """Transformer encoder-based dynamical systems decoder.
-    * Trained on MLM loss
-    * Returns loss & predicted rates."""
-    def __init__(self, in_dim=100, out_dim=2, hidden_dims=[400, 1, 1, .0, .0], max_context_len=2, args=None):
-        """Constructs a Transformer-based decoder.
+    """
+    Transformer encoder-based dynamical systems decoder.
 
-        Parameters
-        ----------
-        in_dim : int
-            Dimensionality of input data
-        out_dim : int
-            Number of output columns
-        hidden_dims : list
-            Containing the architectural parameters of the model
-            (dim_feedforward, num_layers, nhead, dropout, rate_dropout)
-        max_context_len : int
-            Maximum context length
-        args : dict
-            Dictionary containing the hyperparameters of the model
-        """
+    This class implements a Transformer-based decoder trained on MLM loss.
+    It returns loss and predicted rates.
+
+    Parameters
+    ----------
+    in_dim : int, optional
+        Dimensionality of input data, by default 100
+    out_dim : int, optional
+        Number of output columns, by default 2
+    hidden_dims : Tuple[int, int, int, float, float], optional
+        Architectural parameters of the model
+        (dim_feedforward, num_layers, nhead, dropout, rate_dropout),
+        by default (400, 1, 1, 0.0, 0.0)
+    max_context_len : int, optional
+        Maximum context length, by default 2
+    args : Optional[Dict], optional
+        Dictionary containing the hyperparameters of the model, by default None
+
+    Attributes
+    ----------
+    pos_encoder : PositionalEncoding
+        Positional encoding module
+    transformer_encoder : nn.TransformerEncoder
+        Transformer encoder module
+    rate_dropout : nn.Dropout
+        Dropout layer for rates
+    decoder : nn.Sequential
+        Decoder network
+    src_mask : Dict[str, torch.Tensor]
+        Dictionary to store source masks for different devices
+    """
+    def __init__(self, in_dim: int = 100, out_dim: int = 2,
+                 hidden_dims: Tuple[int, int, int, float, float] = (400, 1, 1, 0.0, 0.0),
+                 max_context_len: int = 2, args: Optional[Dict] = None):
         super().__init__()
         self.save_hyperparameters()
         self.max_context_len = max_context_len
         self.in_dim = in_dim
-        self.args = args
-        activations = nn.CELU if self.args['activations'] is None else self.args['activations']
+        self.args = args if args is not None else {}
+        activations = nn.CELU if self.args.get('activations') is None else self.args['activations']
 
-        self.src_mask = {}  # full context, by default
+        self.src_mask: Dict[str, torch.Tensor] = {}
 
-        # self.scale = np.sqrt(in_dim)
-
-        self.pos_encoder = PositionalEncoding(in_dim, max_context_len, args)
+        self.pos_encoder = PositionalEncoding(in_dim, max_context_len, self.args)
         encoder_lyr = nn.TransformerEncoderLayer(
             in_dim,
@@ -73,7 +118,11 @@ def __init__(self, in_dim=100, out_dim=2, hidden_dims=[400, 1, 1, .0, .0], max_c
             nn.Linear(in_dim, 16), activations(),
             nn.Linear(16, out_dim)
         )
-        def init_params(m):
+        self._init_params()
+
+    def _init_params(self) -> None:
+        """Initialize the parameters of the decoder."""
+        def init_params(m: nn.Module) -> None:
             if isinstance(m, nn.Linear):
                 torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu')
                 if m.bias is not None:
@@ -82,16 +131,21 @@ def init_params(m):
                     nn.init.uniform_(m.bias, -bound, bound)  # LeCunn init
         self.decoder.apply(init_params)
-
-    def forward(self, x, mask_labels=None):
-        """Forward pass of the model.
+    def forward(self, x: torch.Tensor, mask_labels: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        Forward pass of the model.
 
         Parameters
         ----------
-        x : torch.tensor (BxLxN)
-            Input data
-        mask_labels : torch.tensor (LxL)
-            Masking labels for the input data
+        x : torch.Tensor
+            Input data of shape (batch_size, seq_len, in_dim)
+        mask_labels : Optional[torch.Tensor], optional
+            Masking labels for the input data, by default None
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor of shape (batch_size, seq_len, out_dim)
         """
         x = x.permute(1, 0, 2)  # LxBxN
         x = self.pos_encoder(x)
@@ -99,11 +153,24 @@ def forward(self, x, mask_labels=None):
         z = self.transformer_encoder(x, x_mask)
         z = self.rate_dropout(z)
         out = self.decoder(z).permute(1, 0, 2)  # B x L x out_dim
-        if self.args['clf']:
+        if self.args.get('clf', False):
             out = F.log_softmax(out, dim=-1)
         return out
 
-    def _get_or_generate_context_mask(self, src):
+    def _get_or_generate_context_mask(self, src: torch.Tensor) -> torch.Tensor:
+        """
+        Get or generate the context mask for the input tensor.
+
+        Parameters
+        ----------
+        src : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            Context mask for the input tensor
+        """
         context_forward = 4
         size = src.size(0)  # T
         mask = (torch.triu(torch.ones(size, size, device=src.device), diagonal=-context_forward) == 1).transpose(0, 1)
@@ -111,35 +178,101 @@ def _get_or_generate_context_mask(self, src):
         self.src_mask[str(src.device)] = mask
         return self.src_mask[str(src.device)]
 
-    def _step(self, batch, batch_idx) -> torch.Tensor:
-        xs, ys = batch  # unpack the batch
-        B, L, N = xs.shape
+    def _step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
+        """
+        Perform a single step (forward pass + loss calculation).
+
+        Parameters
+        ----------
+        batch : tuple
+            Batch of input data and labels
+        batch_idx : int
+            Index of the current batch
+
+        Returns
+        -------
+        torch.Tensor
+            Computed loss
+        """
+        xs, ys = batch
         outs = self(xs)
         loss = self.args['criterion'](outs, ys)
         return loss
 
-    def training_step(self, batch, batch_idx) -> torch.Tensor:
+    def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
+        """
+        Lightning method for training step.
+
+        Parameters
+        ----------
+        batch : tuple
+            Batch of input data and labels
+        batch_idx : int
+            Index of the current batch
+
+        Returns
+        -------
+        torch.Tensor
+            Computed loss
+        """
         loss = self._step(batch, batch_idx)
         self.log('train_loss', loss)
         return loss
 
-    def validation_step(self, batch, batch_idx) -> torch.Tensor:
+    def validation_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
+        """
+        Lightning method for validation step.
+
+        Parameters
+        ----------
+        batch : tuple
+            Batch of input data and labels
+        batch_idx : int
+            Index of the current batch
+
+        Returns
+        -------
+        torch.Tensor
+            Computed loss
+        """
         loss = self._step(batch, batch_idx)
         self.log('val_loss', loss)
         return loss
 
-    def test_step(self, batch, batch_idx) -> torch.Tensor:
+    def test_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
+        """
+        Lightning method for test step.
+
+        Parameters
+        ----------
+        batch : tuple
+            Batch of input data and labels
+        batch_idx : int
+            Index of the current batch
+
+        Returns
+        -------
+        torch.Tensor
+            Computed loss
+        """
         loss = self._step(batch, batch_idx)
         self.log('test_loss', loss)
         return loss
 
-    def configure_optimizers(self):
-        args = self.args
+    def configure_optimizers(self) -> tuple:
+        """
+        Configure optimizers and learning rate schedulers.
+
+        Returns
+        -------
+        tuple
+            A list of optimizers and a list of learning-rate scheduler configurations
+        """
         optimizer = torch.optim.AdamW(
-            self.parameters(), weight_decay=args['weight_decay'])
+            self.parameters(), weight_decay=self.args['weight_decay'])
         scheduler = torch.optim.lr_scheduler.OneCycleLR(
-            optimizer, max_lr=args['lr'],
-            epochs=args['epochs'],
+            optimizer, max_lr=self.args['lr'],
+            epochs=self.args['epochs'],
             total_steps=self.trainer.estimated_stepping_batches
         )
         lr_scheduler = {'scheduler': scheduler, 'interval': 'step'}
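As a quick sanity check on the masking expression in _get_or_generate_context_mask, the same banded mask can be reproduced standalone. With context_forward = 4, entry (i, j) is True whenever key position j <= i + 4, i.e. every timestep sees the full past plus at most four future steps, assuming the boolean mask is converted downstream so that True marks attendable positions (that conversion is outside the hunks shown here):

    import torch

    context_forward = 4  # same constant as in _get_or_generate_context_mask
    size = 8             # sequence length T, kept small so the mask is easy to print
    mask = (torch.triu(torch.ones(size, size), diagonal=-context_forward) == 1).transpose(0, 1)
    print(mask.int())    # row i has ones exactly at columns j <= i + context_forward
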
diff --git a/tutorials/decoding.ipynb b/tutorials/decoding.ipynb
index 18b09aa..2345340 100644
--- a/tutorials/decoding.ipynb
+++ b/tutorials/decoding.ipynb
@@ -1007,7 +1007,7 @@
    "source": [
     "decoder_type = 'MLP'  # Select decoder type (e.g., MLP)\n",
     "hyperparams = dict(\n",
-    "    batch_size=512*8,\n",
+    "    batch_size=512*4,\n",
     "    num_workers=5,\n",
     "    model=decoder_type,\n",
     "    model_args=dict(\n",
@@ -1018,7 +1018,7 @@
     "    clf=False,\n",
     "    activations=nn.CELU,\n",
     "    criterion=F.mse_loss,\n",
-    "    epochs=50,\n",
+    "    epochs=10,\n",
     "    lr=3e-2,\n",
     "    base_lr=1e-2,\n",
     "    max_grad_norm=1.,\n",
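To tie the module changes to the notebook cell above, a minimal sketch of constructing NDT from neuro_py.ensemble.decoding.transformer and pushing a dummy batch through it. The clf, activations, criterion, lr, and epochs values mirror the notebook cell; weight_decay and the tensor sizes are placeholders, and the optimizer-related keys only matter if configure_optimizers is invoked by a Trainer:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from neuro_py.ensemble.decoding.transformer import NDT

    args = dict(activations=nn.CELU, clf=False, criterion=F.mse_loss,
                weight_decay=1e-2, lr=3e-2, epochs=10)
    model = NDT(
        in_dim=30, out_dim=2,
        hidden_dims=(400, 1, 1, 0.0, 0.0),  # (dim_feedforward, num_layers, nhead, dropout, rate_dropout)
        max_context_len=50, args=args,
    )
    x = torch.randn(8, 50, 30)  # (batch, seq_len, n_neurons); seq_len must not exceed max_context_len
    out = model(x)              # expected shape: (8, 50, 2)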