Utility to extract layer configurations from a basic Transformer model
ziatdinovmax committed Feb 4, 2025
1 parent 1dc64de commit 0c10262
Showing 1 changed file with 131 additions and 0 deletions.
neurobayes/flax_nets/configs.py (131 additions, 0 deletions)
@@ -2,6 +2,7 @@
import flax.linen as nn
from .mlp import FlaxMLP, FlaxMLP2Head
from .convnet import FlaxConvNet, FlaxConvNet2Head
from .transformer import FlaxTransformer


def extract_mlp_configs(
@@ -279,4 +280,134 @@ def extract_convnet2head_configs(
"layer_name": layer_name
})

return configs

def extract_transformer_configs(
    net: FlaxTransformer,
    probabilistic_layers: List[str] = None,
    num_probabilistic_layers: int = None
) -> List[Dict]:
    """Extract layer configurations from a Transformer model.

    Args:
        net: The FlaxTransformer instance
        probabilistic_layers: List of layer names to be treated as probabilistic.
            Valid layer names include:
            - "TokenEmbed"
            - "PosEmbed"
            - "Block{i}_Attention" for i in range(num_layers)
            - "Block{i}_MLP_dense1" for i in range(num_layers)
            - "Block{i}_MLP_dense2" for i in range(num_layers)
            - "FinalDense1"
            - "FinalDense2"
        num_probabilistic_layers: Number of final layers to be probabilistic
            (counting from the end of the network)

    Returns:
        A list of per-layer configuration dictionaries, in network order.
    """
    if (probabilistic_layers is None) == (num_probabilistic_layers is None):
        raise ValueError(
            "Exactly one of 'probabilistic_layers' or 'num_probabilistic_layers' must be specified"
        )

    # Get all layer names in order
    layer_names = []
    # First embeddings
    layer_names.extend(["TokenEmbed", "PosEmbed"])
    # Then transformer blocks
    for i in range(net.num_layers):
        layer_names.extend([
            f"Block{i}_Attention",
            f"Block{i}_MLP_dense1",
            f"Block{i}_MLP_dense2"
        ])
    # Finally output layers
    layer_names.extend(["FinalDense1", "FinalDense2"])

    # If using num_probabilistic_layers, create probabilistic_layers list
    if num_probabilistic_layers is not None:
        probabilistic_layers = layer_names[-num_probabilistic_layers:]

    activation_fn = nn.silu if net.activation == 'silu' else nn.tanh

    configs = []

    # 1. Embedding layers
    configs.extend([
        # Token embedding
        {
            "features": net.d_model,
            "num_embeddings": net.vocab_size,
            "activation": None,
            "is_probabilistic": "TokenEmbed" in probabilistic_layers,
            "layer_type": "embedding",
            "layer_name": "TokenEmbed"
        },
        # Position embedding
        {
            "features": net.d_model,
            "num_embeddings": net.max_seq_length,
            "activation": None,
            "is_probabilistic": "PosEmbed" in probabilistic_layers,
            "layer_type": "embedding",
            "layer_name": "PosEmbed"
        }
    ])

    # 2. Transformer blocks
    for i in range(net.num_layers):
        # a. Self-attention
        configs.append({
            "features": net.d_model,
            "activation": None,
            "is_probabilistic": f"Block{i}_Attention" in probabilistic_layers,
            "layer_type": "attention",
            "layer_name": f"Block{i}_Attention",
            "num_heads": net.nhead,
            "qkv_features": net.d_model,
            "dropout_rate": net.dropout_rate
        })

        # b. MLP part
        configs.extend([
            # First dense layer
            {
                "features": net.dim_feedforward,
                "activation": activation_fn,
                "is_probabilistic": f"Block{i}_MLP_dense1" in probabilistic_layers,
                "layer_type": "fc",
                "layer_name": f"Block{i}_MLP_dense1",
                "dropout_rate": net.dropout_rate
            },
            # Second dense layer
            {
                "features": net.d_model,  # Back to d_model size
                "activation": None,
                "is_probabilistic": f"Block{i}_MLP_dense2" in probabilistic_layers,
                "layer_type": "fc",
                "layer_name": f"Block{i}_MLP_dense2",
                "dropout_rate": net.dropout_rate
            }
        ])

    # 3. Final layers
    configs.extend([
        # First final dense
        {
            "features": net.dim_feedforward,
            "activation": activation_fn,
            "is_probabilistic": "FinalDense1" in probabilistic_layers,
            "layer_type": "fc",
            "layer_name": "FinalDense1",
            "dropout_rate": net.dropout_rate
        },
        # Output layer
        {
            "features": 1,
            "activation": None,  # nn.softmax if net.classification else None,
            "is_probabilistic": "FinalDense2" in probabilistic_layers,
            "layer_type": "fc",
            "layer_name": "FinalDense2",
            "dropout_rate": net.dropout_rate
        }
    ])

    return configs
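
A minimal usage sketch of the new utility. The FlaxTransformer constructor keywords below are assumptions inferred from the attributes this function reads (vocab_size, d_model, nhead, num_layers, dim_feedforward, max_seq_length, dropout_rate, activation) and may not match the actual constructor signature:

from neurobayes.flax_nets.transformer import FlaxTransformer
from neurobayes.flax_nets.configs import extract_transformer_configs

# Hypothetical constructor call; the real FlaxTransformer signature may differ.
net = FlaxTransformer(
    vocab_size=1000, d_model=64, nhead=4, num_layers=2,
    dim_feedforward=128, max_seq_length=256,
    dropout_rate=0.1, activation='silu',
)

# Option 1: name the probabilistic layers explicitly.
configs = extract_transformer_configs(
    net, probabilistic_layers=["Block1_MLP_dense2", "FinalDense1", "FinalDense2"]
)

# Option 2: make the last two layers (FinalDense1, FinalDense2) probabilistic.
configs = extract_transformer_configs(net, num_probabilistic_layers=2)

# Each entry is a plain dict, e.g.
# {"features": 128, "activation": nn.silu, "is_probabilistic": True,
#  "layer_type": "fc", "layer_name": "FinalDense1", "dropout_rate": 0.1}

Exactly one of the two keyword arguments must be given; passing both or neither raises the ValueError shown above.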
