Merge pull request #33 from ziatdinovmax/dropout
Add basic Transformer (deterministic and partially Bayesian)
ziatdinovmax authored Feb 6, 2025
2 parents 6c808bc + 41991c6 commit f0a6f10
Showing 15 changed files with 675 additions and 68 deletions.
1 change: 1 addition & 0 deletions neurobayes/__init__.py
@@ -7,6 +7,7 @@
from .models.bnn_heteroskedastic import HeteroskedasticBNN
from .models.bnn_heteroskedastic_model import VarianceModelHeteroskedasticBNN
from .models.partial_bnn_heteroskedastic import HeteroskedasticPartialBNN
from .models.partial_btnn import PartialBTNN
from .flax_nets.deterministic_nn import DeterministicNN
from .flax_nets.convnet import FlaxConvNet, FlaxConvNet2Head, FlaxMLP, FlaxMLP2Head

1 change: 1 addition & 0 deletions neurobayes/flax_nets/__init__.py
@@ -1,5 +1,6 @@
from .convnet import *
from .mlp import *
from .transformer import *
from .deterministic_nn import *
from .splitter import *
from .configs import *
10 changes: 8 additions & 2 deletions neurobayes/flax_nets/config_utils.py
@@ -3,7 +3,8 @@

from .convnet import FlaxConvNet, FlaxConvNet2Head
from .mlp import FlaxMLP, FlaxMLP2Head
from .configs import extract_mlp_configs, extract_convnet_configs, extract_mlp2head_configs, extract_convnet2head_configs
from .transformer import FlaxTransformer
from .configs import extract_mlp_configs, extract_convnet_configs, extract_mlp2head_configs, extract_convnet2head_configs, extract_transformer_configs

@singledispatch
def extract_configs(net, probabilistic_layers: List[str] = None,
@@ -29,4 +30,9 @@ def _(net: FlaxConvNet, probabilistic_layers: List[str] = None,
@extract_configs.register
def _(net: FlaxConvNet2Head, probabilistic_layers: List[str] = None,
num_probabilistic_layers: int = None) -> List[Dict]:
return extract_convnet2head_configs(net, probabilistic_layers, num_probabilistic_layers)

@extract_configs.register
def _(net: FlaxTransformer, probabilistic_layers: List[str] = None,
num_probabilistic_layers: int = None) -> List[Dict]:
return extract_transformer_configs(net, probabilistic_layers, num_probabilistic_layers)
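
As a usage sketch (not part of this diff), the registration above means extract_configs can be called directly on a FlaxTransformer instance and singledispatch will route it to extract_transformer_configs. The constructor arguments below are assumptions inferred from the attributes that extract_transformer_configs reads; the actual FlaxTransformer signature is not shown in this diff.

# Hypothetical sketch; field names are inferred from extract_transformer_configs,
# and any remaining constructor fields are assumed to have defaults.
from neurobayes.flax_nets.transformer import FlaxTransformer
from neurobayes.flax_nets.config_utils import extract_configs

net = FlaxTransformer(
    vocab_size=1000,      # read as transformer.vocab_size
    d_model=64,           # read as transformer.d_model
    nhead=4,              # read as transformer.nhead
    num_layers=2,         # read as transformer.num_layers
    dim_feedforward=128,  # read as transformer.dim_feedforward
    max_seq_length=32,    # read as transformer.max_seq_length
)

# Dispatches to extract_transformer_configs registered above:
layer_configs = extract_configs(net, num_probabilistic_layers=2)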
106 changes: 106 additions & 0 deletions neurobayes/flax_nets/configs.py
@@ -2,6 +2,7 @@
import flax.linen as nn
from .mlp import FlaxMLP, FlaxMLP2Head
from .convnet import FlaxConvNet, FlaxConvNet2Head
from .transformer import FlaxTransformer


def extract_mlp_configs(
@@ -279,4 +280,109 @@ def extract_convnet2head_configs(
"layer_name": layer_name
})

return configs


def extract_transformer_configs(
transformer: FlaxTransformer,
probabilistic_layers: List[str] = None,
num_probabilistic_layers: int = None
) -> List[Dict]:
"""
Extract layer configurations from a Transformer model.
Args:
transformer: The FlaxTransformer instance
probabilistic_layers: List of layer names to be treated as probabilistic
num_probabilistic_layers: Number of final layers to be probabilistic
"""
if (probabilistic_layers is None) == (num_probabilistic_layers is None):
raise ValueError(
"Exactly one of 'probabilistic_layers' or 'num_probabilistic_layers' must be specified"
)

# Get activation function
activation_fn = nn.silu if transformer.activation == 'silu' else nn.tanh

configs = []

# Embedding layer configs
configs.append({
"features": transformer.d_model,
"num_embeddings": transformer.vocab_size,
"is_probabilistic": "TokenEmbed" in (probabilistic_layers or []),
"layer_type": "embedding",
"layer_name": "TokenEmbed"
})

configs.append({
"features": transformer.d_model,
"num_embeddings": transformer.max_seq_length,
"is_probabilistic": "PosEmbed" in (probabilistic_layers or []),
"layer_type": "embedding",
"layer_name": "PosEmbed"
})

# For each transformer block
for i in range(transformer.num_layers):
configs.append({
"num_heads": transformer.nhead,
"qkv_features": transformer.d_model,
"dropout_rate": transformer.dropout_rate,
"is_probabilistic": f"Block{i}_Attention" in (probabilistic_layers or []),
"layer_type": "attention",
"layer_name": f"Block{i}_Attention"
})

configs.append({
"is_probabilistic": f"Block{i}_LayerNorm1" in (probabilistic_layers or []),
"layer_type": "layernorm",
"layer_name": f"Block{i}_LayerNorm1"
})

configs.append({
"features": transformer.dim_feedforward,
"activation": activation_fn,
"is_probabilistic": f"Block{i}_MLP_dense1" in (probabilistic_layers or []),
"layer_type": "fc",
"layer_name": f"Block{i}_MLP_dense1"
})

configs.append({
"features": transformer.d_model,
"activation": None,
"is_probabilistic": f"Block{i}_MLP_dense2" in (probabilistic_layers or []),
"layer_type": "fc",
"layer_name": f"Block{i}_MLP_dense2"
})

configs.append({
"is_probabilistic": f"Block{i}_LayerNorm2" in (probabilistic_layers or []),
"layer_type": "layernorm",
"layer_name": f"Block{i}_LayerNorm2"
})

# Final layers
configs.append({
"features": transformer.dim_feedforward,
"activation": activation_fn,
"is_probabilistic": "FinalDense1" in (probabilistic_layers or []),
"layer_type": "fc",
"layer_name": "FinalDense1"
})

configs.append({
"features": 1,
"activation": nn.softmax if transformer.classification else None,
"is_probabilistic": "FinalDense2" in (probabilistic_layers or []),
"layer_type": "fc",
"layer_name": "FinalDense2"
})

# If using num_probabilistic_layers, update is_probabilistic flags
if num_probabilistic_layers is not None:
total_layers = len(configs)
for i, config in enumerate(configs):
config["is_probabilistic"] = i >= (total_layers - num_probabilistic_layers)

return configs
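
For reference, a short sketch (not part of the diff) of the layer ordering and naming convention this function establishes, and of how num_probabilistic_layers flags only the trailing entries:

# Names derived from the f-strings above for num_layers=2; this does not
# require instantiating a model.
expected_names = (
    ["TokenEmbed", "PosEmbed"]
    + [f"Block{i}_{part}"
       for i in range(2)
       for part in ("Attention", "LayerNorm1", "MLP_dense1", "MLP_dense2", "LayerNorm2")]
    + ["FinalDense1", "FinalDense2"]
)
assert len(expected_names) == 14  # 2 embeddings + 5 per block * 2 blocks + 2 final layers

# With num_probabilistic_layers=3, only the last three entries are flagged probabilistic:
assert expected_names[-3:] == ["Block1_LayerNorm2", "FinalDense1", "FinalDense2"]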
36 changes: 28 additions & 8 deletions neurobayes/flax_nets/convnet.py
@@ -10,14 +10,17 @@ class ConvLayerModule(nn.Module):
input_dim: int
kernel_size: Union[int, Tuple[int, ...]]
activation: Any = None
dropout: float = 0.0
layer_name: str = None

@nn.compact
def __call__(self, x):
def __call__(self, x, enable_dropout: bool = True):
conv, pool = get_conv_and_pool_ops(self.input_dim, self.kernel_size)
x = conv(features=self.features, name=self.layer_name)(x)
if self.activation is not None:
x = self.activation(x)
if self.dropout > 0:
x = nn.Dropout(rate=self.dropout)(x, deterministic=not enable_dropout)
x = pool(x)
return x

@@ -28,10 +31,14 @@ class FlaxConvNet(nn.Module):
target_dim: int
activation: str = 'tanh'
kernel_size: Union[int, Tuple[int, ...]] = 3
classification: bool = False # Explicit flag for classification tasks
conv_dropout: float = 0.0
hidden_dropout: float = 0.0
output_dropout: float = 0.0
classification: bool = False

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
def __call__(self, x: jnp.ndarray, enable_dropout: bool = True
) -> jnp.ndarray:
activation_fn = nn.tanh if self.activation == 'tanh' else nn.silu

# Convolutional layers
@@ -41,9 +48,10 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
input_dim=self.input_dim,
kernel_size=self.kernel_size,
activation=activation_fn,
dropout=self.conv_dropout,
layer_name=f"Conv{i}"
)
x = conv_layer(x)
x = conv_layer(x, enable_dropout=enable_dropout)

# Flatten the output for the fully connected layers
x = x.reshape((x.shape[0], -1))
@@ -53,7 +61,10 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
hidden_dims=self.fc_layers,
target_dim=self.target_dim,
activation=self.activation,
classification=self.classification)(x)
hidden_dropout=self.hidden_dropout,
output_dropout=self.output_dropout,
classification=self.classification
)(x, enable_dropout=enable_dropout)

return x

@@ -64,10 +75,15 @@ class FlaxConvNet2Head(nn.Module):
fc_layers: Sequence[int]
target_dim: int
activation: str = 'tanh'
conv_dropout: float = 0.0
hidden_dropout: float = 0.0
output_dropout: float = 0.0
kernel_size: Union[int, Tuple[int, ...]] = 3

@nn.compact
def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
def __call__(self, x: jnp.ndarray, enable_dropout: bool = True
) -> Tuple[jnp.ndarray, jnp.ndarray]:

activation_fn = nn.tanh if self.activation == 'tanh' else nn.silu

# Convolutional layers
@@ -77,17 +93,21 @@ def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
input_dim=self.input_dim,
kernel_size=self.kernel_size,
activation=activation_fn,
dropout=self.conv_dropout,
layer_name=f"Conv{i}"
)
x = conv_layer(x)
x = conv_layer(x, enable_dropout=enable_dropout)

# Flatten the output for the fully connected layers
x = x.reshape((x.shape[0], -1))

mean, var = FlaxMLP2Head(
hidden_dims=self.fc_layers,
target_dim=self.target_dim,
activation=self.activation)(x)
activation=self.activation,
hidden_dropout=self.hidden_dropout,
output_dropout=self.output_dropout
)(x, enable_dropout=enable_dropout)

return mean, var
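
The enable_dropout flag added throughout follows the standard Flax pattern: nn.Dropout is bypassed when deterministic=True, and a 'dropout' RNG stream must be supplied to apply() whenever dropout is active. A minimal self-contained sketch of that pattern (illustrative names only, not the repository's API):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyDropoutNet(nn.Module):
    """Toy module mirroring the enable_dropout pattern used above."""
    dropout: float = 0.1

    @nn.compact
    def __call__(self, x, enable_dropout: bool = True):
        x = nn.Dense(8)(x)
        # Dropout is active only when enable_dropout=True (training / stochastic passes).
        x = nn.Dropout(rate=self.dropout)(x, deterministic=not enable_dropout)
        return nn.Dense(1)(x)

model = TinyDropoutNet()
x = jnp.ones((4, 3))
params = model.init(jax.random.PRNGKey(0), x, enable_dropout=False)

# Deterministic inference: no dropout RNG needed.
y_eval = model.apply(params, x, enable_dropout=False)

# Stochastic forward pass: provide the 'dropout' RNG stream.
y_sample = model.apply(params, x, enable_dropout=True,
                       rngs={"dropout": jax.random.PRNGKey(1)})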

