feat: more transformer definitions such as rope scaling
Signed-off-by: Zhao Chen <[email protected]>
aftersnow committed Oct 16, 2024
1 parent b095d7e commit d63f940
Showing 3 changed files with 73 additions and 30 deletions.
30 changes: 30 additions & 0 deletions docs/v2/architecture.md
@@ -0,0 +1,30 @@
# Architecture

## Tensor naming convention

[version].[vendor].[family].[name].[arch].[modality].[block_name].[layer_name].[tensor_name].[tensor_type]

Any dot within a component (for example, in the model name) should be replaced with an underscore, since the dot is reserved as the separator between components.

### Naming Conventions

- **version**: The version of the naming convention.
- **vendor**: The vendor of the model.
- **family**: The family of the model.
- **name**: The name of the model.
- **arch**: The architecture of the model.
- **modality**: The modality of the model.
- **block_name**: The name of the block.
- **layer_name**: The name of the layer combined with its 0-indexed layer number, e.g. `layers_0`.
- **tensor_name**: The name of the tensor.
- **tensor_type**: Weight or bias of the tensor.

### Example

```plain
v1.meta.llama-3_2-1b.transformer.text.decoder.layers_0.embedding.projection.weight
```

```plain
v1.meta.llama-3_2-1b.transformer.text.decoder.layers_1.attention.query.weight
```
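
As a sanity check, here is a minimal Go sketch (not part of this spec) of how such a name might be assembled. The `tensorName` helper is hypothetical, and the component values are taken from the first example above.

```go
package main

import (
	"fmt"
	"strings"
)

// tensorName is a hypothetical helper (not defined by this spec): it replaces
// any dot inside a component with an underscore, then joins the components
// with dots so the dot stays unambiguous as the separator.
func tensorName(components ...string) string {
	parts := make([]string, len(components))
	for i, c := range components {
		parts[i] = strings.ReplaceAll(c, ".", "_")
	}
	return strings.Join(parts, ".")
}

func main() {
	// Components taken from the first example above; "llama-3.2-1b" becomes
	// "llama-3_2-1b" after the dot-to-underscore replacement.
	fmt.Println(tensorName(
		"v1", "meta", "llama-3.2-1b", "transformer", "text",
		"decoder", "layers_0", "embedding", "projection", "weight",
	))
	// v1.meta.llama-3_2-1b.transformer.text.decoder.layers_0.embedding.projection.weight
}
```

Performing the dot-to-underscore replacement per component before joining keeps the dot unambiguous as the component separator.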
50 changes: 43 additions & 7 deletions specs-go/v2/architecture.go
@@ -5,7 +5,7 @@ type Architecture struct {
// Transformer architecture
Transformer Transformer `json:"transformer"`

// TODO: Other architectures
// TODO: Other architectures, like mamba, etc.
}

// Transformer represents the transformer architecture.
@@ -30,6 +30,9 @@ type TransformerForCausalLM struct {
// The hidden size of the model
HiddenSize int `json:"hidden_size"`

// embedding
Embedding Embedding `json:"embedding"`

// Position embedding type
PositionEmbedding PositionEmbedding `json:"position_embedding"`

@@ -45,8 +48,11 @@ type TransformerForCausalLM struct {

// TransformerLayer represents the transformer layer parameters.
type TransformerLayer struct {
// Attention parameters
Attention Attention `json:"attention"`
MLP MLP `json:"mlp"`

// MLP parameters
MLP MLP `json:"mlp"`
}

// MLP represents the MLP (Multi-Layer Perceptron) parameters.
@@ -97,6 +103,9 @@ type Attention struct {
// Number of key-value heads
NumKeyValueHeads int `json:"num_key_value_heads"`

// The attention head dimension. If 0, it defaults to HiddenSize / NumAttentionHeads.
HeadDim int `json:"head_dim"`

// Whether the attention has a residual connection
HasResidual bool `json:"has_residual"`

@@ -112,17 +121,35 @@ type Attention struct {

// PositionEmbedding represents the position embedding type and parameters.
type PositionEmbedding struct {
// Type of position embedding, e.g. 'rope', 'sinusoidal', 'alibi', etc.
// Type of position embedding, e.g. 'rope', 'alibi', etc.
Type string `json:"type"`

// The maximum number of position embeddings
MaxPositionEmbeddings int `json:"max_position_embeddings"`

// The base signifying the rotary embedding period.
RotaryEmbeddingBase int `json:"rotary_embedding_base,omitempty"`
// Only used with 'RoPE'. The theta parameter in the RoPE position embedding.
RotaryEmbeddingTheta float64 `json:"rope_theta,omitempty"`

// Only used with 'RoPE'. The scaling configuration for the RoPE embeddings
RotaryEmbeddingScaling RotaryEmbeddingScaling `json:"rope_scaling,omitempty"`
}

// RotaryEmbeddingScaling represents the scaling configuration for the RoPE embeddings.
type RotaryEmbeddingScaling struct {
// Type of scaling. One of ['default', 'linear', 'dynamic', 'llama3'], with 'default' being the original RoPE implementation.
Type string `json:"type"`

// The scaling factor
Factor float64 `json:"factor"`

// The original max position used during pretraining.
OriginalMaxPosition int `json:"original_max_position"`

// Fraction of hidden size to apply rotary embeddings to. Must be in [0,1].
RotaryEmbeddingFraction float64 `json:"rotary_embedding_fraction,omitempty"`
// Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
LowFreqFactor float64 `json:"low_freq_factor"`

// Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
HighFreqFactor float64 `json:"high_freq_factor"`
}

// Normalization represents the normalization parameters.
@@ -133,3 +160,12 @@ type Normalization struct {
// Epsilon for the normalization
Epsilon float64 `json:"epsilon"`
}

// Embedding represents the embedding parameters.
type Embedding struct {
// Whether the embedding has a bias
HasBias bool `json:"has_bias"`

// Whether the embedding has a normalization
HasNorm bool `json:"has_norm"`
}
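
Below is a minimal, self-contained Go sketch (not part of this commit) of how the new RoPE fields might be populated and serialized through the JSON tags defined above. The two structs are re-declared locally only to keep the sketch runnable, and the numeric values are illustrative, loosely following a Llama 3.x style configuration.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the RotaryEmbeddingScaling and PositionEmbedding structs
// above, re-declared here only so the sketch is self-contained.
type RotaryEmbeddingScaling struct {
	Type                string  `json:"type"`
	Factor              float64 `json:"factor"`
	OriginalMaxPosition int     `json:"original_max_position"`
	LowFreqFactor       float64 `json:"low_freq_factor"`
	HighFreqFactor      float64 `json:"high_freq_factor"`
}

type PositionEmbedding struct {
	Type                   string                 `json:"type"`
	MaxPositionEmbeddings  int                    `json:"max_position_embeddings"`
	RotaryEmbeddingTheta   float64                `json:"rope_theta,omitempty"`
	RotaryEmbeddingScaling RotaryEmbeddingScaling `json:"rope_scaling,omitempty"`
}

func main() {
	// Illustrative values, loosely following a Llama 3.x style configuration.
	pe := PositionEmbedding{
		Type:                  "rope",
		MaxPositionEmbeddings: 131072,
		RotaryEmbeddingTheta:  500000,
		RotaryEmbeddingScaling: RotaryEmbeddingScaling{
			Type:                "llama3",
			Factor:              8.0,
			OriginalMaxPosition: 8192,
			LowFreqFactor:       1.0,
			HighFreqFactor:      4.0,
		},
	}

	out, err := json.MarshalIndent(pe, "", "  ")
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}
```

Marshaling the value yields a `position_embedding` object whose `rope_scaling` member carries the `llama3` scaling parameters, matching the field names declared in the spec.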
23 changes: 0 additions & 23 deletions specs-go/v2/architecture.md

This file was deleted.
