From d020d47b9b9e267812192df099d80999d9888337 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Wed, 22 Jan 2025 21:59:12 -0800 Subject: [PATCH 01/15] Add dropout option to MLP module --- neurobayes/flax_nets/convnet.py | 10 +++- neurobayes/flax_nets/deterministic_nn.py | 62 ++++++++++++++++-------- neurobayes/flax_nets/mlp.py | 32 ++++++++---- neurobayes/models/bnn.py | 36 ++++++++------ neurobayes/models/partial_bnn.py | 6 +-- 5 files changed, 98 insertions(+), 48 deletions(-) diff --git a/neurobayes/flax_nets/convnet.py b/neurobayes/flax_nets/convnet.py index fe4962f..7aa5c5a 100644 --- a/neurobayes/flax_nets/convnet.py +++ b/neurobayes/flax_nets/convnet.py @@ -28,6 +28,8 @@ class FlaxConvNet(nn.Module): target_dim: int activation: str = 'tanh' kernel_size: Union[int, Tuple[int, ...]] = 3 + hidden_dropout: float = 0.0 + output_dropout: float = 0.0 classification: bool = False # Explicit flag for classification tasks @nn.compact @@ -53,6 +55,8 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: hidden_dims=self.fc_layers, target_dim=self.target_dim, activation=self.activation, + hidden_dropout=self.hidden_dropout, + output_dropout=self.output_dropout, classification=self.classification)(x) return x @@ -64,6 +68,8 @@ class FlaxConvNet2Head(nn.Module): fc_layers: Sequence[int] target_dim: int activation: str = 'tanh' + hidden_dropout: float = 0.0 + output_dropout: float = 0.0 kernel_size: Union[int, Tuple[int, ...]] = 3 @nn.compact @@ -87,7 +93,9 @@ def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: mean, var = FlaxMLP2Head( hidden_dims=self.fc_layers, target_dim=self.target_dim, - activation=self.activation)(x) + activation=self.activation, + hidden_dropout=self.hidden_dropout, + output_dropout=self.output_dropout)(x) return mean, var diff --git a/neurobayes/flax_nets/deterministic_nn.py b/neurobayes/flax_nets/deterministic_nn.py index 7f7977c..7a8208b 100644 --- a/neurobayes/flax_nets/deterministic_nn.py +++ b/neurobayes/flax_nets/deterministic_nn.py @@ -86,7 +86,13 @@ def __init__(self, @partial(jax.jit, static_argnums=(0,)) def train_step(self, state, inputs, targets): """JIT-compiled training step""" - loss, grads = jax.value_and_grad(self.total_loss)(state.params, inputs, targets) + dropout_key = jax.random.PRNGKey(state.step) + loss, grads = jax.value_and_grad(self.total_loss)( + state.params, + inputs, + targets, + rngs={'dropout': dropout_key} + ) state = state.apply_gradients(grads=grads) return state, loss @@ -165,44 +171,60 @@ def average_params(self) -> Dict: def reset_swa(self): """Reset SWA collections""" self.params_history = [] - - def mse_loss(self, params: Dict, inputs: jnp.ndarray, - targets: jnp.ndarray) -> jnp.ndarray: - predictions = self.model.apply({'params': params}, inputs) - return jnp.mean((predictions - targets) ** 2) - - def heteroskedastic_loss(self, params: Dict, inputs: jnp.ndarray, - targets: jnp.ndarray) -> jnp.ndarray: - y_pred, y_var = self.model.apply({'params': params}, inputs) - return jnp.mean(0.5 * jnp.log(y_var) + 0.5 * (targets - y_pred)**2 / y_var) - - def cross_entropy_loss(self, params: Dict, inputs: jnp.ndarray, - targets: jnp.ndarray) -> jnp.ndarray: - logits = self.model.apply({'params': params}, inputs) - return -jnp.mean(jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)) def gaussian_prior(self, params: Dict) -> jnp.ndarray: l2_norm = sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params)) return l2_norm / (2 * self.sigma**2) - def total_loss(self, params: Dict, inputs: jnp.ndarray, targets: 
jnp.ndarray) -> jnp.ndarray: + def total_loss(self, params: Dict, inputs: jnp.ndarray, targets: jnp.ndarray, + rngs: Dict) -> jnp.ndarray: if self.loss == 'classification': - loss = self.cross_entropy_loss(params, inputs, targets) + loss = self.cross_entropy_loss(params, inputs, targets, rngs) else: loss_fn = self.mse_loss if self.loss == 'homoskedastic' else self.heteroskedastic_loss - loss = loss_fn(params, inputs, targets) + loss = loss_fn(params, inputs, targets, rngs) if self.map: prior_loss = self.gaussian_prior(params) / len(inputs) loss += prior_loss return loss + def mse_loss(self, params: Dict, inputs: jnp.ndarray, + targets: jnp.ndarray, rngs: Dict) -> jnp.ndarray: + predictions = self.model.apply( + {'params': params}, + inputs, + enable_dropout=True, + rngs=rngs + ) + return jnp.mean((predictions - targets) ** 2) + + def heteroskedastic_loss(self, params: Dict, inputs: jnp.ndarray, + targets: jnp.ndarray, rngs: Dict) -> jnp.ndarray: + y_pred, y_var = self.model.apply( + {'params': params}, + inputs, + enable_dropout=True, + rngs=rngs + ) + return jnp.mean(0.5 * jnp.log(y_var) + 0.5 * (targets - y_pred)**2 / y_var) + + def cross_entropy_loss(self, params: Dict, inputs: jnp.ndarray, + targets: jnp.ndarray, rngs: Dict) -> jnp.ndarray: + logits = self.model.apply( + {'params': params}, + inputs, + enable_dropout=True, + rngs=rngs + ) + return -jnp.mean(jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)) + def predict(self, X: jnp.ndarray) -> jnp.ndarray: X = self.set_data(X) return self._predict(self.state, X) @partial(jax.jit, static_argnums=(0,)) def _predict(self, state, X): - predictions = state.apply_fn({'params': state.params}, X) + predictions = state.apply_fn({'params': state.params}, X, enable_dropout=False) if self.loss == 'classification': return jax.nn.softmax(predictions) return predictions diff --git a/neurobayes/flax_nets/mlp.py b/neurobayes/flax_nets/mlp.py index 9a99ac5..3991d65 100644 --- a/neurobayes/flax_nets/mlp.py +++ b/neurobayes/flax_nets/mlp.py @@ -7,13 +7,16 @@ class MLPLayerModule(nn.Module): features: int activation: Any = None + dropout: float = 0.0 layer_name: str = 'dense' @nn.compact - def __call__(self, x): + def __call__(self, x, enable_dropout: bool = True): x = nn.Dense(features=self.features, name=self.layer_name)(x) if self.activation is not None: x = self.activation(x) + if self.dropout > 0: + x = nn.Dropout(rate=self.dropout)(x, deterministic=not enable_dropout) return x @@ -21,10 +24,13 @@ class FlaxMLP(nn.Module): hidden_dims: Sequence[int] target_dim: int activation: str = 'tanh' - classification: bool = False # Explicit flag for classification tasks + hidden_dropout: float = 0.0 + output_dropout: float = 0.0 + classification: bool = False + @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray, enable_dropout: bool = True) -> jnp.ndarray: """Forward pass of the MLP""" # Set the activation function @@ -35,17 +41,19 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: layer = MLPLayerModule( features=hidden_dim, activation=activation_fn, + dropout=self.hidden_dropout, layer_name=f"Dense{i}" ) - x = layer(x) + x = layer(x, enable_dropout) # Output layer output_layer = MLPLayerModule( features=self.target_dim, activation=nn.softmax if self.classification else None, + dropout=self.output_dropout, layer_name=f"Dense{len(self.hidden_dims)}" ) - x = output_layer(x) + x = output_layer(x, enable_dropout) return x @@ -54,9 +62,12 @@ class FlaxMLP2Head(nn.Module): hidden_dims: Sequence[int] 
target_dim: int activation: str = 'tanh' + hidden_dropout: float = 0.0 + output_dropout: float = 0.0 @nn.compact - def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + def __call__(self, x: jnp.ndarray, enable_dropout: bool = True + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Forward pass of the 2-headed MLP""" # Set the activation function @@ -67,24 +78,27 @@ def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: layer = MLPLayerModule( features=hidden_dim, activation=activation_fn, + dropout=self.hidden_dropout, layer_name=f"Dense{i}" ) - x = layer(x) + x = layer(x, enable_dropout) # Mean head mean_layer = MLPLayerModule( features=self.target_dim, activation=None, + dropout=self.output_dropout, layer_name="MeanHead" ) - mean = mean_layer(x) + mean = mean_layer(x, enable_dropout) # Variance head var_layer = MLPLayerModule( features=self.target_dim, activation=nn.softplus, + dropout=self.output_dropout, layer_name="VarianceHead" ) - variance = var_layer(x) + variance = var_layer(x, enable_dropout) return mean, variance \ No newline at end of file diff --git a/neurobayes/models/bnn.py b/neurobayes/models/bnn.py index edaf5fd..09e906c 100644 --- a/neurobayes/models/bnn.py +++ b/neurobayes/models/bnn.py @@ -68,14 +68,13 @@ def is_regression(self) -> bool: return self.num_classes is None def model(self, - X: jnp.ndarray, - y: jnp.ndarray = None, - priors_sigma: float = 1.0, - **kwargs) -> None: - """Unified BNN model for both regression and classification""" - + X: jnp.ndarray, + y: jnp.ndarray = None, + priors_sigma: float = 1.0, + **kwargs) -> None: + pretrained_priors = (flatten_params_dict(self.pretrained_priors) - if self.pretrained_priors is not None else None) + if self.pretrained_priors is not None else None) def prior(name, shape): if pretrained_priors is not None: @@ -86,17 +85,23 @@ def prior(name, shape): return dist.Normal(0., priors_sigma) input_shape = X.shape[1:] if X.ndim > 2 else (X.shape[-1],) + net = random_flax_module( - "nn", self.nn, input_shape=(1, *input_shape), prior=prior) + "nn", + self.nn, + input_shape=(1, *input_shape), + prior=prior, + ) - if self.is_regression: - # Regression case - mu = numpyro.deterministic("mu", net(X)) + if self.is_regression: # Regression case + mu = numpyro.deterministic( + "mu", + net(X, enable_dropout=False) + ) sig = numpyro.sample("sig", self.noise_prior) numpyro.sample("y", dist.Normal(mu, sig), obs=y) - else: - # Classification case - logits = net(X) + else: # Classification case + logits = net(X, enable_dropout=False) probs = numpyro.deterministic("probs", softmax(logits, axis=-1)) numpyro.sample("y", dist.Categorical(probs=probs), obs=y) @@ -266,7 +271,8 @@ def sample_from_posterior(self, ) -> jnp.ndarray: """Sample from posterior distribution at new inputs X_new""" predictive = Predictive( - self.model, samples, + self.model, + samples, return_sites=return_sites ) return predictive(rng_key, X_new) diff --git a/neurobayes/models/partial_bnn.py b/neurobayes/models/partial_bnn.py index eb67a46..1c41970 100644 --- a/neurobayes/models/partial_bnn.py +++ b/neurobayes/models/partial_bnn.py @@ -63,7 +63,7 @@ def prior(name, shape): layer_name = param_path[0] param_type = param_path[-1] # kernel or bias return dist.Normal(pretrained_priors[layer_name][param_type], priors_sigma) - + current_input = X # Track when we switch from conv to dense layers @@ -93,7 +93,7 @@ def prior(name, shape): input_shape=(1, *current_input.shape[1:]), prior=prior ) - current_input = net(current_input) + current_input = 
net(current_input, enable_dropout=False) else: params = { "params": { @@ -103,7 +103,7 @@ def prior(name, shape): } } } - current_input = layer.apply(params, current_input) + current_input = layer.apply(params, current_input, enable_dropout=False) if self.is_regression: # Regression case mu = numpyro.deterministic("mu", current_input) From abbf078775ca1a8963aa946f0cf719b5534c02e2 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Thu, 23 Jan 2025 10:56:18 -0800 Subject: [PATCH 02/15] Add dropout option for conv layers --- neurobayes/flax_nets/convnet.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/neurobayes/flax_nets/convnet.py b/neurobayes/flax_nets/convnet.py index 7aa5c5a..2cbf72b 100644 --- a/neurobayes/flax_nets/convnet.py +++ b/neurobayes/flax_nets/convnet.py @@ -10,14 +10,17 @@ class ConvLayerModule(nn.Module): input_dim: int kernel_size: Union[int, Tuple[int, ...]] activation: Any = None + dropout: float = 0.0 layer_name: str = None @nn.compact - def __call__(self, x): + def __call__(self, x, enable_dropout: bool = True): conv, pool = get_conv_and_pool_ops(self.input_dim, self.kernel_size) x = conv(features=self.features, name=self.layer_name)(x) if self.activation is not None: x = self.activation(x) + if self.dropout > 0: + x = nn.Dropout(rate=self.dropout)(x, deterministic=not enable_dropout) x = pool(x) return x @@ -28,12 +31,14 @@ class FlaxConvNet(nn.Module): target_dim: int activation: str = 'tanh' kernel_size: Union[int, Tuple[int, ...]] = 3 + conv_dropout: float = 0.0 hidden_dropout: float = 0.0 output_dropout: float = 0.0 - classification: bool = False # Explicit flag for classification tasks + classification: bool = False @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray, enable_dropout: bool = True + ) -> jnp.ndarray: activation_fn = nn.tanh if self.activation == 'tanh' else nn.silu # Convolutional layers @@ -43,9 +48,10 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: input_dim=self.input_dim, kernel_size=self.kernel_size, activation=activation_fn, + dropout=self.conv_dropout, layer_name=f"Conv{i}" ) - x = conv_layer(x) + x = conv_layer(x, enable_dropout=enable_dropout) # Flatten the output for the fully connected layers x = x.reshape((x.shape[0], -1)) @@ -57,7 +63,8 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: activation=self.activation, hidden_dropout=self.hidden_dropout, output_dropout=self.output_dropout, - classification=self.classification)(x) + classification=self.classification + )(x, enable_dropout=enable_dropout) return x @@ -68,12 +75,15 @@ class FlaxConvNet2Head(nn.Module): fc_layers: Sequence[int] target_dim: int activation: str = 'tanh' + conv_dropout: float = 0.0 hidden_dropout: float = 0.0 output_dropout: float = 0.0 kernel_size: Union[int, Tuple[int, ...]] = 3 @nn.compact - def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + def __call__(self, x: jnp.ndarray, enable_dropout: bool = True + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + activation_fn = nn.tanh if self.activation == 'tanh' else nn.silu # Convolutional layers @@ -83,9 +93,10 @@ def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: input_dim=self.input_dim, kernel_size=self.kernel_size, activation=activation_fn, + dropout=self.conv_dropout, layer_name=f"Conv{i}" ) - x = conv_layer(x) + x = conv_layer(x, enable_dropout=enable_dropout) # Flatten the output for the fully connected layers x = x.reshape((x.shape[0], -1)) @@ -95,7 +106,8 @@ def 
__call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: target_dim=self.target_dim, activation=self.activation, hidden_dropout=self.hidden_dropout, - output_dropout=self.output_dropout)(x) + output_dropout=self.output_dropout + )(x, enable_dropout=enable_dropout) return mean, var From 8d2f864164e7c7e51ce1b68e8b7b1b302a0b4d30 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Thu, 23 Jan 2025 12:25:04 -0800 Subject: [PATCH 03/15] Handle dropout properly in heteroskedastic BNNs --- neurobayes/models/bnn_heteroskedastic.py | 2 +- neurobayes/models/bnn_heteroskedastic_model.py | 2 +- neurobayes/models/partial_bnn_heteroskedastic.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/neurobayes/models/bnn_heteroskedastic.py b/neurobayes/models/bnn_heteroskedastic.py index 9ed9dfa..3533d5c 100644 --- a/neurobayes/models/bnn_heteroskedastic.py +++ b/neurobayes/models/bnn_heteroskedastic.py @@ -54,7 +54,7 @@ def prior(name, shape): "nn", self.nn, input_shape=(1, *input_shape), prior=prior) # Pass inputs through a NN with the sampled parameters - mu, sig = net(X) + mu, sig = net(X, enable_dropout=False) # Register values with numpyro mu = numpyro.deterministic("mu", mu) sig = numpyro.deterministic("sig", sig) diff --git a/neurobayes/models/bnn_heteroskedastic_model.py b/neurobayes/models/bnn_heteroskedastic_model.py index 574466e..e241d45 100644 --- a/neurobayes/models/bnn_heteroskedastic_model.py +++ b/neurobayes/models/bnn_heteroskedastic_model.py @@ -36,7 +36,7 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, priors_sigma: float = 1.0 prior=(lambda name, shape: dist.Normal(0, priors_sigma))) # Pass inputs through a NN with the sampled parameters - mu = numpyro.deterministic("mu", net(X)) + mu = numpyro.deterministic("mu", net(X, enable_dropout=False)) # Sample noise variance according to the provided model var_params = self.variance_model_prior() diff --git a/neurobayes/models/partial_bnn_heteroskedastic.py b/neurobayes/models/partial_bnn_heteroskedastic.py index f2cf4b5..3c89695 100644 --- a/neurobayes/models/partial_bnn_heteroskedastic.py +++ b/neurobayes/models/partial_bnn_heteroskedastic.py @@ -90,7 +90,7 @@ def prior(name, shape): input_shape=(1, *current_input.shape[1:]), prior=prior ) - current_input = net(current_input) + current_input = net(current_input, enable_dropout=False) else: params = { "params": { @@ -100,7 +100,7 @@ def prior(name, shape): } } } - current_input = layer.apply(params, current_input) + current_input = layer.apply(params, current_input, enable_dropout=False) # Process head layers shared_output = current_input From 102467271cd76a11d78f71c2ea656d2a63a2d40e Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Tue, 28 Jan 2025 22:03:46 -0800 Subject: [PATCH 04/15] Exclude normalization layers from SWA --- neurobayes/flax_nets/deterministic_nn.py | 26 +++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/neurobayes/flax_nets/deterministic_nn.py b/neurobayes/flax_nets/deterministic_nn.py index 7a8208b..44b9e57 100644 --- a/neurobayes/flax_nets/deterministic_nn.py +++ b/neurobayes/flax_nets/deterministic_nn.py @@ -158,15 +158,31 @@ def _store_params(self, params: Dict) -> None: self.params_history.append(params) def average_params(self) -> Dict: + """Average model parameters, excluding normalization layers""" if not self.params_history: return self.state.params - # Compute the element-wise average of all stored parameters - avg_params = jax.tree_util.tree_map( - lambda *param_trees: 
jnp.mean(jnp.stack(param_trees), axis=0), - *self.params_history + def should_average(path_tuple): + """Check if parameter should be averaged based on its path""" + path_str = '/'.join(str(p) for p in path_tuple) + skip_patterns = ['LayerNorm', 'BatchNorm', 'embedding/norm'] + return not any(pattern in path_str for pattern in skip_patterns) + + def average_leaves(*leaves, path=()): + """Average parameters if not in normalization layers""" + if should_average(path): + return jnp.mean(jnp.stack(leaves), axis=0) + else: + return leaves[0] # Keep original parameters + + # Apply averaging with path information + averaged_params = jax.tree_util.tree_map_with_path( + lambda path, *values: average_leaves(*values, path=path), + self.params_history[0], + *self.params_history[1:] ) - return avg_params + + return averaged_params def reset_swa(self): """Reset SWA collections""" From 1dc64de666b7da136580c24c5b48df38388fddcc Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Wed, 29 Jan 2025 19:23:14 -0800 Subject: [PATCH 05/15] Add basic transformer blocks --- neurobayes/flax_nets/transformer.py | 159 ++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 neurobayes/flax_nets/transformer.py diff --git a/neurobayes/flax_nets/transformer.py b/neurobayes/flax_nets/transformer.py new file mode 100644 index 0000000..b18dde8 --- /dev/null +++ b/neurobayes/flax_nets/transformer.py @@ -0,0 +1,159 @@ +from typing import List, Dict +import jax.numpy as jnp +import flax.linen as nn + + +class EmbedModule(nn.Module): + features: int + num_embeddings: int + layer_name: str = 'embed' + + @nn.compact + def __call__(self, x): + return nn.Embed( + num_embeddings=self.num_embeddings, + features=self.features, + name=self.layer_name + )(x) + + +class TransformerAttentionModule(nn.Module): + num_heads: int + qkv_features: int + dropout_rate: float = 0.1 + layer_name: str = 'attention' + block_idx: int = 0 + + @nn.compact + def __call__(self, x, enable_dropout: bool = True): + return nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + qkv_features=self.qkv_features, + dropout_rate=self.dropout_rate, + deterministic=not enable_dropout, + name=f"Block{self.block_idx}_{self.layer_name}" + )(x, x) + + +class TransformerMLPModule(nn.Module): + features: int + output_dim: int + activation: str = 'silu' + dropout_rate: float = 0.1 + layer_name: str = 'mlp' + block_idx: int = 0 + + @nn.compact + def __call__(self, x, enable_dropout: bool = True): + activation_fn = nn.silu if self.activation == 'silu' else nn.tanh + x = nn.Dense( + features=self.features, + name=f"Block{self.block_idx}_{self.layer_name}_dense1" + )(x) + x = activation_fn(x) + x = nn.Dropout(rate=self.dropout_rate, deterministic=not enable_dropout)(x) + x = nn.Dense( + features=self.output_dim, + name=f"Block{self.block_idx}_{self.layer_name}_dense2" + )(x) + return x + + +class TransformerBlock(nn.Module): + d_model: int + nhead: int + dim_feedforward: int + activation: str = 'silu' + dropout_rate: float = 0.1 + block_idx: int = 0 + + @nn.compact + def __call__(self, x, enable_dropout: bool = True): + # Multi-head self-attention + attention = TransformerAttentionModule( + num_heads=self.nhead, + qkv_features=self.d_model, + dropout_rate=self.dropout_rate, + layer_name="Attention", + block_idx=self.block_idx + )(x, enable_dropout) + + # First residual and norm + x = x + attention + x = nn.LayerNorm(name=f"Block{self.block_idx}_LayerNorm1")(x) + + # MLP block + mlp = TransformerMLPModule( + features=self.dim_feedforward, + 
output_dim=self.d_model, + activation=self.activation, + dropout_rate=self.dropout_rate, + layer_name="MLP", + block_idx=self.block_idx + )(x, enable_dropout) + + # Second residual and norm + x = x + mlp + x = nn.LayerNorm(name=f"Block{self.block_idx}_LayerNorm2")(x) + + return x + + +class FlaxTransformer(nn.Module): + """Transformer model""" + vocab_size: int + d_model: int = 256 + nhead: int = 8 + num_layers: int = 4 + dim_feedforward: int = 1024 + activation: str = 'silu' + dropout_rate: float = 0.1 + classification: bool = False + max_seq_length: int = 1024 + + @nn.compact + def __call__(self, x, enable_dropout: bool = True): + + # Embedding layers + x = nn.Embed( + num_embeddings=self.vocab_size, + features=self.d_model, + name="TokenEmbed" + )(x) + + positions = jnp.arange(x.shape[1])[None, :] + position_embedding = nn.Embed( + num_embeddings=self.max_seq_length, + features=self.d_model, + name="PosEmbed" + )(positions) + x = x + position_embedding + + # Transformer blocks + for i in range(self.num_layers): + x = TransformerBlock( + d_model=self.d_model, + nhead=self.nhead, + dim_feedforward=self.dim_feedforward, + activation=self.activation, + dropout_rate=self.dropout_rate, + block_idx=i + )(x, enable_dropout=enable_dropout) + + # Pooling and final layers + activation_fn = nn.silu if self.activation == 'silu' else nn.tanh + x = jnp.mean(x, axis=1) + x = nn.Dense( + features=self.dim_feedforward, + name="FinalDense1" + )(x) + x = activation_fn(x) + x = nn.Dropout(rate=self.dropout_rate, deterministic=not enable_dropout)(x) + x = nn.Dense( + features=1, + name="FinalDense2" + )(x) + if self.classification: + x = nn.softmax(x) + + return x.squeeze(-1) From 0c102621ca211a1f8e05c9342139751643673ea9 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Tue, 4 Feb 2025 09:39:15 -0800 Subject: [PATCH 06/15] Utility to extract layer configurations from a basic Transformer model --- neurobayes/flax_nets/configs.py | 131 ++++++++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) diff --git a/neurobayes/flax_nets/configs.py b/neurobayes/flax_nets/configs.py index 2fe874d..499750e 100644 --- a/neurobayes/flax_nets/configs.py +++ b/neurobayes/flax_nets/configs.py @@ -2,6 +2,7 @@ import flax.linen as nn from .mlp import FlaxMLP, FlaxMLP2Head from .convnet import FlaxConvNet, FlaxConvNet2Head +from .transformer import FlaxTransformer def extract_mlp_configs( @@ -279,4 +280,134 @@ def extract_convnet2head_configs( "layer_name": layer_name }) + return configs + +def extract_transformer_configs( + net: FlaxTransformer, + probabilistic_layers: List[str] = None, + num_probabilistic_layers: int = None +) -> List[Dict]: + """Extract layer configurations from a Transformer model. + + Args: + net: The FlaxTransformer instance + probabilistic_layers: List of layer names to be treated as probabilistic. 
+ Valid layer names include: + - "TokenEmbed" + - "PosEmbed" + - "Block{i}_Attention" for i in range(num_layers) + - "Block{i}_MLP_dense1" for i in range(num_layers) + - "Block{i}_MLP_dense2" for i in range(num_layers) + - "FinalDense1" + - "FinalDense2" + num_probabilistic_layers: Number of final layers to be probabilistic + (counting from the end of the network) + """ + if (probabilistic_layers is None) == (num_probabilistic_layers is None): + raise ValueError( + "Exactly one of 'probabilistic_layers' or 'num_probabilistic_layers' must be specified" + ) + + # Get all layer names in order + layer_names = [] + # First embeddings + layer_names.extend(["TokenEmbed", "PosEmbed"]) + # Then transformer blocks + for i in range(net.num_layers): + layer_names.extend([ + f"Block{i}_Attention", + f"Block{i}_MLP_dense1", + f"Block{i}_MLP_dense2" + ]) + # Finally output layers + layer_names.extend(["FinalDense1", "FinalDense2"]) + + # If using num_probabilistic_layers, create probabilistic_layers list + if num_probabilistic_layers is not None: + probabilistic_layers = layer_names[-num_probabilistic_layers:] + + activation_fn = nn.silu if net.activation == 'silu' else nn.tanh + + configs = [] + + # 1. Embedding layers + configs.extend([ + # Token embedding + { + "features": net.d_model, + "num_embeddings": net.vocab_size, + "activation": None, + "is_probabilistic": "TokenEmbed" in probabilistic_layers, + "layer_type": "embedding", + "layer_name": "TokenEmbed" + }, + # Position embedding + { + "features": net.d_model, + "num_embeddings": net.max_seq_length, + "activation": None, + "is_probabilistic": "PosEmbed" in probabilistic_layers, + "layer_type": "embedding", + "layer_name": "PosEmbed" + } + ]) + + # 2. Transformer blocks + for i in range(net.num_layers): + # a. Self-attention + configs.append({ + "features": net.d_model, + "activation": None, + "is_probabilistic": f"Block{i}_Attention" in probabilistic_layers, + "layer_type": "attention", + "layer_name": f"Block{i}_Attention", + "num_heads": net.nhead, + "qkv_features": net.d_model, + "dropout_rate": net.dropout_rate + }) + + # b. MLP part + configs.extend([ + # First dense layer + { + "features": net.dim_feedforward, + "activation": activation_fn, + "is_probabilistic": f"Block{i}_MLP_dense1" in probabilistic_layers, + "layer_type": "fc", + "layer_name": f"Block{i}_MLP_dense1", + "dropout_rate": net.dropout_rate + }, + # Second dense layer + { + "features": net.d_model, # Back to d_model size + "activation": None, + "is_probabilistic": f"Block{i}_MLP_dense2" in probabilistic_layers, + "layer_type": "fc", + "layer_name": f"Block{i}_MLP_dense2", + "dropout_rate": net.dropout_rate + } + ]) + + # 3. 
Final layers + configs.extend([ + # First final dense + { + "features": net.dim_feedforward, + "activation": activation_fn, + "is_probabilistic": "FinalDense1" in probabilistic_layers, + "layer_type": "fc", + "layer_name": "FinalDense1", + "dropout_rate": net.dropout_rate + }, + # Output layer + { + "features": 1, + "activation": None, # nn.softmax if net.classification else None, + "is_probabilistic": "FinalDense2" in probabilistic_layers, + "layer_type": "fc", + "layer_name": "FinalDense2", + "dropout_rate": net.dropout_rate + } + ]) + return configs \ No newline at end of file From 88fbb089e5975068707022b363bba1c33b7b6d14 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Tue, 4 Feb 2025 13:15:20 -0800 Subject: [PATCH 07/15] Update imports --- neurobayes/flax_nets/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/neurobayes/flax_nets/__init__.py b/neurobayes/flax_nets/__init__.py index 32d1458..37de00f 100644 --- a/neurobayes/flax_nets/__init__.py +++ b/neurobayes/flax_nets/__init__.py @@ -1,5 +1,6 @@ from .convnet import * from .mlp import * +from .transformer import * from .deterministic_nn import * from .splitter import * from .configs import * From 802431572d114723becabfe12858f3565046e166 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Tue, 4 Feb 2025 13:54:01 -0800 Subject: [PATCH 08/15] Use correct dtype for transformer trainer init --- neurobayes/flax_nets/deterministic_nn.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/neurobayes/flax_nets/deterministic_nn.py b/neurobayes/flax_nets/deterministic_nn.py index 44b9e57..f3a2bac 100644 --- a/neurobayes/flax_nets/deterministic_nn.py +++ b/neurobayes/flax_nets/deterministic_nn.py @@ -49,6 +49,10 @@ def __init__(self, input_shape = (input_shape,) if isinstance(input_shape, int) else input_shape self.model = architecture + + is_transformer = any(base.__name__.lower().find('transformer') >= 0 + for base in architecture.__mro__) + input_dtype = jnp.int32 if is_transformer else jnp.float32 if loss not in ['homoskedastic', 'heteroskedastic', 'classification']: raise ValueError("Select between 'homoskedastic', 'heteroskedastic', or 'classification' loss") @@ -56,7 +60,10 @@ def __init__(self, # Initialize model key = jax.random.PRNGKey(0) - params = self.model.init(key, jnp.ones((1, *input_shape)))['params'] + params = self.model.init( + key, + jnp.ones((1, *input_shape), dtype=input_dtype) + )['params'] # Default SWA configuration with all required parameters self.default_swa_config = { From b5c7ac58f65530e768d095358eef3916e18a7dcd Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Tue, 4 Feb 2025 13:56:50 -0800 Subject: [PATCH 09/15] Fix the dtype selection --- neurobayes/flax_nets/deterministic_nn.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/neurobayes/flax_nets/deterministic_nn.py b/neurobayes/flax_nets/deterministic_nn.py index f3a2bac..f5bff65 100644 --- a/neurobayes/flax_nets/deterministic_nn.py +++ b/neurobayes/flax_nets/deterministic_nn.py @@ -50,8 +50,7 @@ def __init__(self, input_shape = (input_shape,) if isinstance(input_shape, int) else input_shape self.model = architecture - is_transformer = any(base.__name__.lower().find('transformer') >= 0 - for base in architecture.__mro__) + is_transformer = 'transformer' in self.model.__class__.__name__.lower() input_dtype = jnp.int32 if is_transformer else jnp.float32 if loss not in ['homoskedastic', 'heteroskedastic', 'classification']: From 53ef5422bdd78e0876d020021fd34a25af8dfffb Mon Sep 17 00:00:00 2001 From: 
Maxim Ziatdinov Date: Tue, 4 Feb 2025 14:09:43 -0800 Subject: [PATCH 10/15] Fix target dimensionality --- neurobayes/flax_nets/deterministic_nn.py | 2 +- neurobayes/flax_nets/transformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/neurobayes/flax_nets/deterministic_nn.py b/neurobayes/flax_nets/deterministic_nn.py index f5bff65..f5375f0 100644 --- a/neurobayes/flax_nets/deterministic_nn.py +++ b/neurobayes/flax_nets/deterministic_nn.py @@ -259,7 +259,7 @@ def set_data(self, X: jnp.ndarray, y: jnp.ndarray = None) -> jnp.ndarray: y = y.reshape(-1) y = jax.nn.one_hot(y, num_classes=self.model.target_dim) else: - y = y[:, None] if y.ndim < 2 else y # Regression + y = y[:, None] if y.ndim < 2 else y # Regression return X, y return X diff --git a/neurobayes/flax_nets/transformer.py b/neurobayes/flax_nets/transformer.py index b18dde8..f12737b 100644 --- a/neurobayes/flax_nets/transformer.py +++ b/neurobayes/flax_nets/transformer.py @@ -156,4 +156,4 @@ def __call__(self, x, enable_dropout: bool = True): if self.classification: x = nn.softmax(x) - return x.squeeze(-1) + return x#x.squeeze(-1) From 4bc91d2cf774a323c9cae7f679449c8edddf11db Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Tue, 4 Feb 2025 18:02:50 -0800 Subject: [PATCH 11/15] Add extract_transformer_configs to config_utils --- neurobayes/flax_nets/config_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/neurobayes/flax_nets/config_utils.py b/neurobayes/flax_nets/config_utils.py index 0544979..c0b1620 100644 --- a/neurobayes/flax_nets/config_utils.py +++ b/neurobayes/flax_nets/config_utils.py @@ -3,7 +3,8 @@ from .convnet import FlaxConvNet, FlaxConvNet2Head from .mlp import FlaxMLP, FlaxMLP2Head -from .configs import extract_mlp_configs, extract_convnet_configs, extract_mlp2head_configs, extract_convnet2head_configs +from .transformer import FlaxTransformer +from .configs import extract_mlp_configs, extract_convnet_configs, extract_mlp2head_configs, extract_convnet2head_configs, extract_transformer_configs @singledispatch def extract_configs(net, probabilistic_layers: List[str] = None, @@ -29,4 +30,9 @@ def _(net: FlaxConvNet, probabilistic_layers: List[str] = None, @extract_configs.register def _(net: FlaxConvNet2Head, probabilistic_layers: List[str] = None, num_probabilistic_layers: int = None) -> List[Dict]: - return extract_convnet2head_configs(net, probabilistic_layers, num_probabilistic_layers) \ No newline at end of file + return extract_convnet2head_configs(net, probabilistic_layers, num_probabilistic_layers) + +@extract_configs.register +def _(net: FlaxTransformer, probabilistic_layers: List[str] = None, + num_probabilistic_layers: int = None) -> List[Dict]: + return extract_transformer_configs(net, probabilistic_layers, num_probabilistic_layers) \ No newline at end of file From 107ac2a08f4df633d4c1d0fd32b443626d6bd054 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Wed, 5 Feb 2025 14:31:23 -0800 Subject: [PATCH 12/15] First attempt at Partial BTNN --- neurobayes/__init__.py | 1 + neurobayes/flax_nets/configs.py | 199 +++++++++++++--------------- neurobayes/flax_nets/transformer.py | 34 +++-- neurobayes/models/partial_btnn.py | 193 +++++++++++++++++++++++++++ neurobayes/utils/utils.py | 35 +++++ 5 files changed, 341 insertions(+), 121 deletions(-) create mode 100644 neurobayes/models/partial_btnn.py diff --git a/neurobayes/__init__.py b/neurobayes/__init__.py index e6a36ca..584ceeb 100644 --- a/neurobayes/__init__.py +++ b/neurobayes/__init__.py @@ -7,6 
+7,7 @@ from .models.bnn_heteroskedastic import HeteroskedasticBNN from .models.bnn_heteroskedastic_model import VarianceModelHeteroskedasticBNN from .models.partial_bnn_heteroskedastic import HeteroskedasticPartialBNN +from .models.partial_btnn import PartialBTNN from .flax_nets.deterministic_nn import DeterministicNN from .flax_nets.convnet import FlaxConvNet, FlaxConvNet2Head, FlaxMLP, FlaxMLP2Head diff --git a/neurobayes/flax_nets/configs.py b/neurobayes/flax_nets/configs.py index 499750e..841fb9a 100644 --- a/neurobayes/flax_nets/configs.py +++ b/neurobayes/flax_nets/configs.py @@ -282,132 +282,111 @@ def extract_convnet2head_configs( return configs + def extract_transformer_configs( - net: FlaxTransformer, + transformer: FlaxTransformer, probabilistic_layers: List[str] = None, num_probabilistic_layers: int = None ) -> List[Dict]: - """Extract layer configurations from a Transformer model. - + """ + Extract layer configurations from a Transformer model. + Args: - net: The FlaxTransformer instance - probabilistic_layers: List of layer names to be treated as probabilistic. - Valid layer names include: - - "TokenEmbed" - - "PosEmbed" - - "Block{i}_Attention" for i in range(num_layers) - - "Block{i}_MLP_dense1" for i in range(num_layers) - - "Block{i}_MLP_dense2" for i in range(num_layers) - - "FinalDense1" - - "FinalDense2" + transformer: The FlaxTransformer instance + probabilistic_layers: List of layer names to be treated as probabilistic num_probabilistic_layers: Number of final layers to be probabilistic - (counting from the end of the network) """ if (probabilistic_layers is None) == (num_probabilistic_layers is None): raise ValueError( "Exactly one of 'probabilistic_layers' or 'num_probabilistic_layers' must be specified" ) - - # Get all layer names in order - layer_names = [] - # First embeddings - layer_names.extend(["TokenEmbed", "PosEmbed"]) - # Then transformer blocks - for i in range(net.num_layers): - layer_names.extend([ - f"Block{i}_Attention", - f"Block{i}_MLP_dense1", - f"Block{i}_MLP_dense2" - ]) - # Finally output layers - layer_names.extend(["FinalDense1", "FinalDense2"]) - - # If using num_probabilistic_layers, create probabilistic_layers list - if num_probabilistic_layers is not None: - probabilistic_layers = layer_names[-num_probabilistic_layers:] - - activation_fn = nn.silu if net.activation == 'silu' else nn.tanh - + + # Get activation function + activation_fn = nn.silu if transformer.activation == 'silu' else nn.tanh + configs = [] - - # 1. Embedding layers - configs.extend([ - # Token embedding - { - "features": net.d_model, - "num_embeddings": net.vocab_size, - "activation": None, - "is_probabilistic": "TokenEmbed" in probabilistic_layers, - "layer_type": "embedding", - "layer_name": "TokenEmbed" - }, - # Position embedding - { - "features": net.d_model, - "num_embeddings": net.max_seq_length, - "activation": None, - "is_probabilistic": "PosEmbed" in probabilistic_layers, - "layer_type": "embedding", - "layer_name": "PosEmbed" - } - ]) - - # 2. Transformer blocks - for i in range(net.num_layers): - # a. 
Self-attention + + # Embedding layer configs + configs.append({ + "features": transformer.d_model, + "num_embeddings": transformer.vocab_size, + "is_probabilistic": "TokenEmbed" in (probabilistic_layers or []), + "layer_type": "embedding", + "layer_name": "TokenEmbed" + }) + + configs.append({ + "features": transformer.d_model, + "num_embeddings": transformer.max_seq_length, + "is_probabilistic": "PosEmbed" in (probabilistic_layers or []), + "layer_type": "embedding", + "layer_name": "PosEmbed" + }) + + # For each transformer block + for i in range(transformer.num_layers): + # Attention config - note modified layer name to match TransformerAttentionModule configs.append({ - "features": net.d_model, - "activation": None, - "is_probabilistic": f"Block{i}_Attention" in probabilistic_layers, + "num_heads": transformer.nhead, + "qkv_features": transformer.d_model, + "dropout_rate": transformer.dropout_rate, + "is_probabilistic": f"Block{i}_Attention" in (probabilistic_layers or []), "layer_type": "attention", - "layer_name": f"Block{i}_Attention", - "num_heads": net.nhead, - "qkv_features": net.d_model, - "dropout_rate": net.dropout_rate + "layer_name": f"Block{i}_Attention" }) - - # b. MLP part - configs.extend([ - # First dense layer - { - "features": net.dim_feedforward, - "activation": activation_fn, - "is_probabilistic": f"Block{i}_MLP_dense1" in probabilistic_layers, - "layer_type": "fc", - "layer_name": f"Block{i}_MLP_dense1", - "dropout_rate": net.dropout_rate - }, - # Second dense layer - { - "features": net.d_model, # Back to d_model size - "activation": None, - "is_probabilistic": f"Block{i}_MLP_dense2" in probabilistic_layers, - "layer_type": "fc", - "layer_name": f"Block{i}_MLP_dense2", - "dropout_rate": net.dropout_rate - } - ]) - - # 3. Final layers - configs.extend([ - # First final dense - { - "features": net.dim_feedforward, + + # Layer norm 1 config - stays the same + configs.append({ + "is_probabilistic": f"Block{i}_LayerNorm1" in (probabilistic_layers or []), + "layer_type": "layernorm", + "layer_name": f"Block{i}_LayerNorm1" + }) + + # MLP configs - note modified layer names to match TransformerMLPModule + configs.append({ + "features": transformer.dim_feedforward, "activation": activation_fn, - "is_probabilistic": "FinalDense1" in probabilistic_layers, + "is_probabilistic": f"Block{i}_MLP_dense1" in (probabilistic_layers or []), "layer_type": "fc", - "layer_name": "FinalDense1", - "dropout_rate": net.dropout_rate - }, - # Output layer - { - "features": 1, - "activation": None, # nn.softmax if net.classification else None, - "is_probabilistic": "FinalDense2" in probabilistic_layers, + "layer_name": f"Block{i}_MLP_dense1" + }) + + configs.append({ + "features": transformer.d_model, + "activation": None, + "is_probabilistic": f"Block{i}_MLP_dense2" in (probabilistic_layers or []), "layer_type": "fc", - "layer_name": "FinalDense2", - "dropout_rate": net.dropout_rate - } - ]) - + "layer_name": f"Block{i}_MLP_dense2" + }) + + # Layer norm 2 config - stays the same + configs.append({ + "is_probabilistic": f"Block{i}_LayerNorm2" in (probabilistic_layers or []), + "layer_type": "layernorm", + "layer_name": f"Block{i}_LayerNorm2" + }) + + # Final layers + configs.append({ + "features": transformer.dim_feedforward, + "activation": activation_fn, + "is_probabilistic": "FinalDense1" in (probabilistic_layers or []), + "layer_type": "fc", + "layer_name": "FinalDense1" + }) + + configs.append({ + "features": 1, + "activation": nn.softmax if transformer.classification else None, + 
"is_probabilistic": "FinalDense2" in (probabilistic_layers or []), + "layer_type": "fc", + "layer_name": "FinalDense2" + }) + + # If using num_probabilistic_layers, update is_probabilistic flags + if num_probabilistic_layers is not None: + total_layers = len(configs) + for i, config in enumerate(configs): + config["is_probabilistic"] = i >= (total_layers - num_probabilistic_layers) + return configs \ No newline at end of file diff --git a/neurobayes/flax_nets/transformer.py b/neurobayes/flax_nets/transformer.py index f12737b..10ee2f3 100644 --- a/neurobayes/flax_nets/transformer.py +++ b/neurobayes/flax_nets/transformer.py @@ -17,6 +17,14 @@ def __call__(self, x): )(x) +class LayerNormModule(nn.Module): + layer_name: str = 'layernorm' + + @nn.compact + def __call__(self, x): + return nn.LayerNorm(name=self.layer_name)(x) + + class TransformerAttentionModule(nn.Module): num_heads: int qkv_features: int @@ -80,7 +88,9 @@ def __call__(self, x, enable_dropout: bool = True): # First residual and norm x = x + attention - x = nn.LayerNorm(name=f"Block{self.block_idx}_LayerNorm1")(x) + x = LayerNormModule( + layer_name=f"Block{self.block_idx}_LayerNorm1" + )(x) # MLP block mlp = TransformerMLPModule( @@ -94,7 +104,9 @@ def __call__(self, x, enable_dropout: bool = True): # Second residual and norm x = x + mlp - x = nn.LayerNorm(name=f"Block{self.block_idx}_LayerNorm2")(x) + x = LayerNormModule( + layer_name=f"Block{self.block_idx}_LayerNorm2" + )(x) return x @@ -113,21 +125,21 @@ class FlaxTransformer(nn.Module): @nn.compact def __call__(self, x, enable_dropout: bool = True): - # Embedding layers - x = nn.Embed( - num_embeddings=self.vocab_size, + token_embed = EmbedModule( features=self.d_model, - name="TokenEmbed" + num_embeddings=self.vocab_size, + layer_name="TokenEmbed" )(x) positions = jnp.arange(x.shape[1])[None, :] - position_embedding = nn.Embed( - num_embeddings=self.max_seq_length, + position_embedding = EmbedModule( features=self.d_model, - name="PosEmbed" + num_embeddings=self.max_seq_length, + layer_name="PosEmbed" )(positions) - x = x + position_embedding + + x = token_embed + position_embedding # Transformer blocks for i in range(self.num_layers): @@ -156,4 +168,4 @@ def __call__(self, x, enable_dropout: bool = True): if self.classification: x = nn.softmax(x) - return x#x.squeeze(-1) + return x \ No newline at end of file diff --git a/neurobayes/models/partial_btnn.py b/neurobayes/models/partial_btnn.py new file mode 100644 index 0000000..65d1187 --- /dev/null +++ b/neurobayes/models/partial_btnn.py @@ -0,0 +1,193 @@ +from typing import Dict, Optional, Type, Tuple, List +import jax.numpy as jnp +from jax.nn import softmax + +import numpyro +import numpyro.distributions as dist +from numpyro.contrib.module import random_flax_module + +from .bnn import BNN +from ..flax_nets import FlaxTransformer, DeterministicNN +from ..flax_nets import MLPLayerModule, TransformerAttentionModule, EmbedModule, LayerNormModule +from ..flax_nets import extract_transformer_configs +from ..utils import flatten_transformer_params_dict + + +class PartialBTNN(BNN): + """ + Partially stochastic (Bayesian) Transformer network. + + Args: + transformer: FlaxTransformer architecture + deterministic_weights: Pre-trained deterministic weights. If not provided, + the transformer will be trained from scratch when running .fit() method + probabilistic_layer_names: Names of transformer modules to be treated probabilistically. 
+ Valid names include: "TokenEmbed", "PosEmbed", "Block{i}_Attention", + "Block{i}_MLP_dense1", "Block{i}_MLP_dense2", "FinalDense1", "FinalDense2" + num_probabilistic_layers: Alternative to probabilistic_layer_names. + Number of final layers to be treated as probabilistic + num_classes: Number of classes for classification task. + If None, the model performs regression. Defaults to None. + noise_prior: Custom prior for observational noise distribution + """ + + def __init__(self, + transformer: Type[FlaxTransformer], + deterministic_weights: Optional[Dict[str, jnp.ndarray]] = None, + probabilistic_layer_names: List[str] = None, + num_probabilistic_layers: int = None, + num_classes: Optional[int] = None, + noise_prior: Optional[dist.Distribution] = None + ) -> None: + super().__init__(None, num_classes, noise_prior) + + self.deterministic_nn = transformer + self.deterministic_weights = deterministic_weights + + # Extract configurations + self.layer_configs = extract_transformer_configs( + transformer, probabilistic_layer_names, num_probabilistic_layers) + + def model(self, X: jnp.ndarray, y: jnp.ndarray = None, priors_sigma: float = 1.0, **kwargs) -> None: + net = self.deterministic_nn + pretrained_priors = flatten_transformer_params_dict(self.deterministic_weights) + + def prior(name, shape): + param_path = name.split('.') + layer_name = param_path[0] + + if len(param_path) == 3: # Attention parameters + component = param_path[1] # 'query', 'key', or 'value' + param_type = param_path[2] # 'kernel' or 'bias' + return dist.Normal(pretrained_priors[layer_name][component][param_type], priors_sigma) + else: # Other parameters + param_type = param_path[-1] + return dist.Normal(pretrained_priors[layer_name][param_type], priors_sigma) + + current_input = X + positions = jnp.arange(X.shape[1])[None, :] + token_embedding = None + pos_embedding = None + + for config in self.layer_configs: + layer_name = config['layer_name'] + layer_type = config['layer_type'] + + # Embeddings have special handling for inputs and outputs + if layer_type == "embedding": + layer = EmbedModule( + features=config['features'], + num_embeddings=config['num_embeddings'], + layer_name=layer_name + ) + input_data = positions if layer_name == 'PosEmbed' else current_input + + if config['is_probabilistic']: + net = random_flax_module(layer_name, layer, + input_shape=(1, *input_data.shape[1:]), prior=prior) + embedding = net(input_data) + else: + params = {"params": {layer_name: pretrained_priors[layer_name]}} + embedding = layer.apply(params, input_data) + + if layer_name == 'TokenEmbed': + token_embedding = embedding + else: # PosEmbed + pos_embedding = embedding + current_input = token_embedding + pos_embedding + + # Layer norms are always deterministic + elif layer_type == "layernorm": + layer = LayerNormModule(layer_name=layer_name) + params = {"params": {layer_name: pretrained_priors[layer_name]}} + current_input = layer.apply(params, current_input) + + # Attention needs block_idx for naming + elif layer_type == "attention": + block_idx = int(layer_name.split('_')[0][5:]) + layer = TransformerAttentionModule( + num_heads=config['num_heads'], + qkv_features=config['qkv_features'], + layer_name="Attention", + block_idx=block_idx + ) + if config['is_probabilistic']: + net = random_flax_module(layer_name, layer, + input_shape=(1, *current_input.shape[1:]), prior=prior) + current_input = net(current_input, enable_dropout=False) + else: + params = {"params": {f"Block{block_idx}_Attention": pretrained_priors[layer_name]}} +
current_input = layer.apply(params, current_input, enable_dropout=False) + + # MLP/Dense layers + else: + layer = MLPLayerModule( + features=config['features'], + activation=config.get('activation'), + layer_name=layer_name + ) + if config['is_probabilistic']: + net = random_flax_module(layer_name, layer, + input_shape=(1, *current_input.shape[1:]), prior=prior) + current_input = net(current_input, enable_dropout=False) + else: + if layer_name.startswith('Block'): + block_idx = int(layer_name.split('_')[0][5:]) + params = {"params": {f"Block{block_idx}_MLP_{layer_name.split('_')[-1]}": pretrained_priors[layer_name]}} + else: + params = {"params": {layer_name: pretrained_priors[layer_name]}} + current_input = layer.apply(params, current_input, enable_dropout=False) + + # Output processing + current_input = jnp.mean(current_input, axis=1) + + if self.is_regression: + mu = numpyro.deterministic("mu", current_input) + sig = numpyro.sample("sig", self.noise_prior) + numpyro.sample("y", dist.Normal(mu, sig), obs=y) + else: + probs = numpyro.deterministic("probs", softmax(current_input, axis=-1)) + numpyro.sample("y", dist.Categorical(probs=probs), obs=y) + + def fit(self, X: jnp.ndarray, y: jnp.ndarray, + num_warmup: int = 2000, num_samples: int = 2000, + num_chains: int = 1, chain_method: str = 'sequential', + sgd_epochs: Optional[int] = None, sgd_lr: Optional[float] = 0.01, + sgd_batch_size: Optional[int] = None, swa_config: Optional[Dict] = None, + map_sigma: float = 1.0, priors_sigma: float = 1.0, + progress_bar: bool = True, device: str = None, + rng_key: Optional[jnp.array] = None, + extra_fields: Optional[Tuple[str, ...]] = (), + **kwargs + ) -> None: + """ + Fit the partially Bayesian transformer. + + Args: + X (jnp.ndarray): Input sequences of shape (batch_size, seq_length). + For other parameters, see BNN.fit() documentation. + """ + + if not self.deterministic_weights: + print("Training deterministic transformer...") + X, y = self.set_data(X, y) + det_nn = DeterministicNN( + self.deterministic_nn, + input_shape=X.shape[1:], + loss='homoskedastic' if self.is_regression else 'classification', + learning_rate=sgd_lr, + swa_config=swa_config, + sigma=map_sigma + ) + det_nn.train( + X, y, + 500 if sgd_epochs is None else sgd_epochs, + sgd_batch_size + ) + self.deterministic_weights = det_nn.state.params + print("Training partially Bayesian transformer") + + super().fit( + X, y, num_warmup, num_samples, num_chains, chain_method, + priors_sigma, progress_bar, device, rng_key, extra_fields, **kwargs + ) \ No newline at end of file diff --git a/neurobayes/utils/utils.py b/neurobayes/utils/utils.py index dddc0d9..4d4b3d3 100644 --- a/neurobayes/utils/utils.py +++ b/neurobayes/utils/utils.py @@ -246,6 +246,41 @@ def flatten_params_dict(params_dict: Dict[str, Any]) -> Dict[str, Any]: return flattened +def flatten_transformer_params_dict(params_dict): + """ + Properly flatten transformer parameter dictionary to match our layer naming scheme.
+ """ + flattened = {} + params = params_dict + + # Map EmbedModule_0/1 to TokenEmbed/PosEmbed + flattened['TokenEmbed'] = params['EmbedModule_0']['TokenEmbed'] + flattened['PosEmbed'] = params['EmbedModule_1']['PosEmbed'] + + # Handle transformer blocks + for i in range(len([k for k in params.keys() if k.startswith('TransformerBlock')])): + block_params = params[f'TransformerBlock_{i}'] + + # LayerNorm parameters are in LayerNormModule_0/1 + flattened[f'Block{i}_LayerNorm1'] = block_params[f'LayerNormModule_0'][f'Block{i}_LayerNorm1'] + flattened[f'Block{i}_LayerNorm2'] = block_params[f'LayerNormModule_1'][f'Block{i}_LayerNorm2'] + + # Attention parameters + flattened[f'Block{i}_Attention'] = block_params[f'TransformerAttentionModule_0'][f'Block{i}_Attention'] + + # MLP parameters + flattened[f'Block{i}_MLP_dense1'] = block_params[f'TransformerMLPModule_0'][f'Block{i}_MLP_dense1'] + flattened[f'Block{i}_MLP_dense2'] = block_params[f'TransformerMLPModule_0'][f'Block{i}_MLP_dense2'] + + # Final dense layers + if 'FinalDense1' in params: + flattened['FinalDense1'] = params['FinalDense1'] + if 'FinalDense2' in params: + flattened['FinalDense2'] = params['FinalDense2'] + + return flattened + + def set_fn(func: Callable) -> Callable: """ Transforms a given deterministic function to use a params dictionary From bced2efe4ab3c1d8f4a5785fd57a5a684b4347ab Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Wed, 5 Feb 2025 15:35:40 -0800 Subject: [PATCH 13/15] Ensure residual connections are in PartialBTNN --- neurobayes/models/partial_btnn.py | 32 ++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/neurobayes/models/partial_btnn.py b/neurobayes/models/partial_btnn.py index 65d1187..628475c 100644 --- a/neurobayes/models/partial_btnn.py +++ b/neurobayes/models/partial_btnn.py @@ -73,7 +73,6 @@ def prior(name, shape): layer_name = config['layer_name'] layer_type = config['layer_type'] - # Embeddings have special handling for inputs and outputs if layer_type == "embedding": layer = EmbedModule( features=config['features'], @@ -95,19 +94,16 @@ def prior(name, shape): else: # PosEmbed pos_embedding = embedding current_input = token_embedding + pos_embedding - - # Layer norms are always deterministic - elif layer_type == "layernorm": - layer = LayerNormModule(layer_name=layer_name) - params = {"params": {layer_name: pretrained_priors[layer_name]}} - current_input = layer.apply(params, current_input) - # Attention needs block_idx for naming elif layer_type == "attention": + # Save input for residual + residual = current_input + block_idx = int(layer_name.split('_')[0][5:]) layer = TransformerAttentionModule( num_heads=config['num_heads'], qkv_features=config['qkv_features'], + dropout_rate=config.get('dropout_rate', 0.1), layer_name="Attention", block_idx=block_idx ) @@ -119,8 +115,19 @@ def prior(name, shape): params = {"params": {f"Block{block_idx}_Attention": pretrained_priors[layer_name]}} current_input = layer.apply(params, current_input, enable_dropout=False) - # MLP/Dense layers - else: + # Add residual after attention + current_input = current_input + residual + + elif layer_type == "layernorm": + layer = LayerNormModule(layer_name=layer_name) + params = {"params": {layer_name: pretrained_priors[layer_name]}} + current_input = layer.apply(params, current_input) + + # Save residual after first layer norm in each block + if layer_name.endswith('LayerNorm1'): + residual = current_input + + else: # fc layers layer = MLPLayerModule( 
features=config['features'], activation=config.get('activation'), @@ -137,8 +144,11 @@ def prior(name, shape): else: params = {"params": {layer_name: pretrained_priors[layer_name]}} current_input = layer.apply(params, current_input, enable_dropout=False) + + # Add residual after second dense layer in each block + if layer_name.endswith('dense2'): + current_input = current_input + residual - # Output processing current_input = jnp.mean(current_input, axis=1) if self.is_regression: From 4e14d4499a71c3e3a64dd8eeb1a840cc4b6a7543 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Wed, 5 Feb 2025 20:01:44 -0800 Subject: [PATCH 14/15] Ensure inputs to EmbedModule are integers --- neurobayes/flax_nets/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neurobayes/flax_nets/transformer.py b/neurobayes/flax_nets/transformer.py index 10ee2f3..838b395 100644 --- a/neurobayes/flax_nets/transformer.py +++ b/neurobayes/flax_nets/transformer.py @@ -10,12 +10,12 @@ class EmbedModule(nn.Module): @nn.compact def __call__(self, x): + x = x.astype(jnp.int32) return nn.Embed( num_embeddings=self.num_embeddings, features=self.features, name=self.layer_name )(x) - class LayerNormModule(nn.Module): layer_name: str = 'layernorm' From 41991c673bc036a343b381da1332640bb1253004 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Thu, 6 Feb 2025 09:42:52 -0800 Subject: [PATCH 15/15] Clean up --- neurobayes/flax_nets/configs.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/neurobayes/flax_nets/configs.py b/neurobayes/flax_nets/configs.py index 841fb9a..88ccbed 100644 --- a/neurobayes/flax_nets/configs.py +++ b/neurobayes/flax_nets/configs.py @@ -325,7 +325,6 @@ def extract_transformer_configs( # For each transformer block for i in range(transformer.num_layers): - # Attention config - note modified layer name to match TransformerAttentionModule configs.append({ "num_heads": transformer.nhead, "qkv_features": transformer.d_model, @@ -335,14 +334,12 @@ def extract_transformer_configs( "layer_name": f"Block{i}_Attention" }) - # Layer norm 1 config - stays the same configs.append({ "is_probabilistic": f"Block{i}_LayerNorm1" in (probabilistic_layers or []), "layer_type": "layernorm", "layer_name": f"Block{i}_LayerNorm1" }) - # MLP configs - note modified layer names to match TransformerMLPModule configs.append({ "features": transformer.dim_feedforward, "activation": activation_fn, @@ -359,7 +356,6 @@ def extract_transformer_configs( "layer_name": f"Block{i}_MLP_dense2" }) - # Layer norm 2 config - stays the same configs.append({ "is_probabilistic": f"Block{i}_LayerNorm2" in (probabilistic_layers or []), "layer_type": "layernorm",
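
Usage sketch (illustrative, not part of the patch series): a minimal example of how the PartialBTNN introduced above might be called, assuming the constructor and fit() signatures shown in neurobayes/models/partial_btnn.py and the FlaxTransformer fields shown in neurobayes/flax_nets/transformer.py. The vocabulary size, model dimensions, token arrays, and hyperparameter values below are hypothetical placeholders.

import jax.numpy as jnp
from neurobayes.flax_nets.transformer import FlaxTransformer
from neurobayes.models.partial_btnn import PartialBTNN

# Deterministic transformer backbone (sizes chosen only for illustration)
transformer = FlaxTransformer(
    vocab_size=128, d_model=64, nhead=4, num_layers=2, dim_feedforward=256
)

# Treat only the two final dense layers as Bayesian; all other layers keep
# their MAP (deterministic) weights
model = PartialBTNN(transformer, num_probabilistic_layers=2)

# Hypothetical data: integer token ids of shape (batch, seq_length) and
# scalar regression targets
X_train = jnp.zeros((32, 16), dtype=jnp.int32)
y_train = jnp.zeros((32,))

# fit() first trains the deterministic transformer (SGD/SWA warm start),
# then runs the MCMC sampling inherited from BNN over the probabilistic layers
model.fit(X_train, y_train, sgd_epochs=100, sgd_lr=1e-3,
          num_warmup=500, num_samples=500)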