Add basic Transformer (deterministic and partially Bayesian) #33

Merged · 15 commits · Feb 6, 2025
1 change: 1 addition & 0 deletions neurobayes/__init__.py
@@ -7,6 +7,7 @@
from .models.bnn_heteroskedastic import HeteroskedasticBNN
from .models.bnn_heteroskedastic_model import VarianceModelHeteroskedasticBNN
from .models.partial_bnn_heteroskedastic import HeteroskedasticPartialBNN
from .models.partial_btnn import PartialBTNN
from .flax_nets.deterministic_nn import DeterministicNN
from .flax_nets.convnet import FlaxConvNet, FlaxConvNet2Head, FlaxMLP, FlaxMLP2Head

1 change: 1 addition & 0 deletions neurobayes/flax_nets/__init__.py
@@ -1,5 +1,6 @@
from .convnet import *
from .mlp import *
from .transformer import *
from .deterministic_nn import *
from .splitter import *
from .configs import *
10 changes: 8 additions & 2 deletions neurobayes/flax_nets/config_utils.py
@@ -3,7 +3,8 @@

from .convnet import FlaxConvNet, FlaxConvNet2Head
from .mlp import FlaxMLP, FlaxMLP2Head
from .configs import extract_mlp_configs, extract_convnet_configs, extract_mlp2head_configs, extract_convnet2head_configs
from .transformer import FlaxTransformer
from .configs import extract_mlp_configs, extract_convnet_configs, extract_mlp2head_configs, extract_convnet2head_configs, extract_transformer_configs

@singledispatch
def extract_configs(net, probabilistic_layers: List[str] = None,
@@ -29,4 +30,9 @@ def _(net: FlaxConvNet, probabilistic_layers: List[str] = None,
@extract_configs.register
def _(net: FlaxConvNet2Head, probabilistic_layers: List[str] = None,
num_probabilistic_layers: int = None) -> List[Dict]:
return extract_convnet2head_configs(net, probabilistic_layers, num_probabilistic_layers)
return extract_convnet2head_configs(net, probabilistic_layers, num_probabilistic_layers)

@extract_configs.register
def _(net: FlaxTransformer, probabilistic_layers: List[str] = None,
num_probabilistic_layers: int = None) -> List[Dict]:
return extract_transformer_configs(net, probabilistic_layers, num_probabilistic_layers)
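
For reference, a minimal usage sketch of the new dispatch path (not part of the diff). The `FlaxTransformer` constructor fields here are assumptions inferred from the attributes read in `extract_transformer_configs` below:

```python
from neurobayes.flax_nets.transformer import FlaxTransformer
from neurobayes.flax_nets.config_utils import extract_configs

# Constructor fields assumed from the attributes accessed in
# extract_transformer_configs (vocab_size, d_model, nhead, num_layers, ...).
net = FlaxTransformer(vocab_size=500, d_model=64, nhead=4, num_layers=2,
                      dim_feedforward=128, max_seq_length=32)

# singledispatch routes this call to extract_transformer_configs;
# layer names follow the scheme defined there (Block{i}_Attention, FinalDense2, ...).
configs = extract_configs(net, probabilistic_layers=["Block0_Attention", "FinalDense2"])
for cfg in configs:
    print(cfg["layer_name"], cfg["layer_type"], cfg["is_probabilistic"])
```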
106 changes: 106 additions & 0 deletions neurobayes/flax_nets/configs.py
@@ -2,6 +2,7 @@
import flax.linen as nn
from .mlp import FlaxMLP, FlaxMLP2Head
from .convnet import FlaxConvNet, FlaxConvNet2Head
from .transformer import FlaxTransformer


def extract_mlp_configs(
@@ -279,4 +280,109 @@ def extract_convnet2head_configs(
"layer_name": layer_name
})

return configs


def extract_transformer_configs(
transformer: FlaxTransformer,
probabilistic_layers: List[str] = None,
num_probabilistic_layers: int = None
) -> List[Dict]:
"""
Extract layer configurations from a Transformer model.

Args:
transformer: The FlaxTransformer instance
probabilistic_layers: List of layer names to be treated as probabilistic
num_probabilistic_layers: Number of final layers to be probabilistic
"""
if (probabilistic_layers is None) == (num_probabilistic_layers is None):
raise ValueError(
"Exactly one of 'probabilistic_layers' or 'num_probabilistic_layers' must be specified"
)

# Get activation function
activation_fn = nn.silu if transformer.activation == 'silu' else nn.tanh

configs = []

# Embedding layer configs
configs.append({
"features": transformer.d_model,
"num_embeddings": transformer.vocab_size,
"is_probabilistic": "TokenEmbed" in (probabilistic_layers or []),
"layer_type": "embedding",
"layer_name": "TokenEmbed"
})

configs.append({
"features": transformer.d_model,
"num_embeddings": transformer.max_seq_length,
"is_probabilistic": "PosEmbed" in (probabilistic_layers or []),
"layer_type": "embedding",
"layer_name": "PosEmbed"
})

# For each transformer block
for i in range(transformer.num_layers):
configs.append({
"num_heads": transformer.nhead,
"qkv_features": transformer.d_model,
"dropout_rate": transformer.dropout_rate,
"is_probabilistic": f"Block{i}_Attention" in (probabilistic_layers or []),
"layer_type": "attention",
"layer_name": f"Block{i}_Attention"
})

configs.append({
"is_probabilistic": f"Block{i}_LayerNorm1" in (probabilistic_layers or []),
"layer_type": "layernorm",
"layer_name": f"Block{i}_LayerNorm1"
})

configs.append({
"features": transformer.dim_feedforward,
"activation": activation_fn,
"is_probabilistic": f"Block{i}_MLP_dense1" in (probabilistic_layers or []),
"layer_type": "fc",
"layer_name": f"Block{i}_MLP_dense1"
})

configs.append({
"features": transformer.d_model,
"activation": None,
"is_probabilistic": f"Block{i}_MLP_dense2" in (probabilistic_layers or []),
"layer_type": "fc",
"layer_name": f"Block{i}_MLP_dense2"
})

configs.append({
"is_probabilistic": f"Block{i}_LayerNorm2" in (probabilistic_layers or []),
"layer_type": "layernorm",
"layer_name": f"Block{i}_LayerNorm2"
})

# Final layers
configs.append({
"features": transformer.dim_feedforward,
"activation": activation_fn,
"is_probabilistic": "FinalDense1" in (probabilistic_layers or []),
"layer_type": "fc",
"layer_name": "FinalDense1"
})

configs.append({
"features": 1,
"activation": nn.softmax if transformer.classification else None,
"is_probabilistic": "FinalDense2" in (probabilistic_layers or []),
"layer_type": "fc",
"layer_name": "FinalDense2"
})

# If using num_probabilistic_layers, update is_probabilistic flags
if num_probabilistic_layers is not None:
total_layers = len(configs)
for i, config in enumerate(configs):
config["is_probabilistic"] = i >= (total_layers - num_probabilistic_layers)

return configs
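
A quick sanity check on the bookkeeping above (a sketch, not part of the diff): each transformer block contributes five configs, so the config list has 2 + 5 * num_layers + 2 entries, and `num_probabilistic_layers` counts back from the end of that list.

```python
# Worked example of the num_probabilistic_layers flag update above.
num_layers = 2
total_configs = 2 + 5 * num_layers + 2   # 2 embeddings + 5 per block + 2 final layers = 14
num_probabilistic_layers = 3
# Entries with index >= total_configs - num_probabilistic_layers become probabilistic:
# Block1_LayerNorm2, FinalDense1, FinalDense2.
```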
36 changes: 28 additions & 8 deletions neurobayes/flax_nets/convnet.py
@@ -10,14 +10,17 @@ class ConvLayerModule(nn.Module):
input_dim: int
kernel_size: Union[int, Tuple[int, ...]]
activation: Any = None
dropout: float = 0.0
layer_name: str = None

@nn.compact
def __call__(self, x):
def __call__(self, x, enable_dropout: bool = True):
conv, pool = get_conv_and_pool_ops(self.input_dim, self.kernel_size)
x = conv(features=self.features, name=self.layer_name)(x)
if self.activation is not None:
x = self.activation(x)
if self.dropout > 0:
x = nn.Dropout(rate=self.dropout)(x, deterministic=not enable_dropout)
x = pool(x)
return x

@@ -28,10 +31,14 @@ class FlaxConvNet(nn.Module):
target_dim: int
activation: str = 'tanh'
kernel_size: Union[int, Tuple[int, ...]] = 3
classification: bool = False # Explicit flag for classification tasks
conv_dropout: float = 0.0
hidden_dropout: float = 0.0
output_dropout: float = 0.0
classification: bool = False

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
def __call__(self, x: jnp.ndarray, enable_dropout: bool = True
) -> jnp.ndarray:
activation_fn = nn.tanh if self.activation == 'tanh' else nn.silu

# Convolutional layers
@@ -41,9 +48,10 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
input_dim=self.input_dim,
kernel_size=self.kernel_size,
activation=activation_fn,
dropout=self.conv_dropout,
layer_name=f"Conv{i}"
)
x = conv_layer(x)
x = conv_layer(x, enable_dropout=enable_dropout)

# Flatten the output for the fully connected layers
x = x.reshape((x.shape[0], -1))
@@ -53,7 +61,10 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
hidden_dims=self.fc_layers,
target_dim=self.target_dim,
activation=self.activation,
classification=self.classification)(x)
hidden_dropout=self.hidden_dropout,
output_dropout=self.output_dropout,
classification=self.classification
)(x, enable_dropout=enable_dropout)

return x

@@ -64,10 +75,15 @@ class FlaxConvNet2Head(nn.Module):
fc_layers: Sequence[int]
target_dim: int
activation: str = 'tanh'
conv_dropout: float = 0.0
hidden_dropout: float = 0.0
output_dropout: float = 0.0
kernel_size: Union[int, Tuple[int, ...]] = 3

@nn.compact
def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
def __call__(self, x: jnp.ndarray, enable_dropout: bool = True
) -> Tuple[jnp.ndarray, jnp.ndarray]:

activation_fn = nn.tanh if self.activation == 'tanh' else nn.silu

# Convolutional layers
@@ -77,17 +93,21 @@ def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
input_dim=self.input_dim,
kernel_size=self.kernel_size,
activation=activation_fn,
dropout=self.conv_dropout,
layer_name=f"Conv{i}"
)
x = conv_layer(x)
x = conv_layer(x, enable_dropout=enable_dropout)

# Flatten the output for the fully connected layers
x = x.reshape((x.shape[0], -1))

mean, var = FlaxMLP2Head(
hidden_dims=self.fc_layers,
target_dim=self.target_dim,
activation=self.activation)(x)
activation=self.activation,
hidden_dropout=self.hidden_dropout,
output_dropout=self.output_dropout
)(x, enable_dropout=enable_dropout)

return mean, var
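
A hedged sketch of the new dropout switch (not part of the diff), exercising `ConvLayerModule` directly; the input layout for `input_dim=1` is an assumption (batch, length, channels):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from neurobayes.flax_nets.convnet import ConvLayerModule

layer = ConvLayerModule(features=8, input_dim=1, kernel_size=3,
                        activation=nn.silu, dropout=0.1, layer_name="Conv0")

x = jnp.ones((4, 32, 1))  # assumed (batch, length, channels) layout for 1D convolutions
params_key, drop_key = jax.random.split(jax.random.PRNGKey(0))

# Dropout is active by default, so init and the training pass need a 'dropout' RNG.
variables = layer.init({'params': params_key, 'dropout': drop_key}, x)
y_train = layer.apply(variables, x, enable_dropout=True, rngs={'dropout': drop_key})

# At inference, enable_dropout=False makes the layer deterministic (no RNG needed).
y_eval = layer.apply(variables, x, enable_dropout=False)
```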
