feat: group-query-attention implementation #74

Merged (24 commits, Mar 20, 2024)
Changes from 12 commits
Commits
f0ea511
feat: group-query-attention implementation
fromm-m Jan 30, 2024
ec8c807
chore: merge main into GQA
flxst Mar 11, 2024
5ae1c63
chore: align configs with new GQA keys
lhahn-iis Mar 11, 2024
8da8c1a
docs: add potential removal marker for "scaling_factor"
lhahn-iis Mar 11, 2024
415b0a6
test: add attention forward pass test for GQA
lhahn-iis Mar 11, 2024
6d4a6cf
fix: add verbose check for divisibility of K,V,Q matrix shapes
lhahn-iis Mar 11, 2024
e0e274d
refactor: remove AttentionConfig
flxst Mar 11, 2024
3725a54
debug: expanded KVs for GQA implementation
lhahn-iis Mar 11, 2024
1239381
fix: group query attention implementation
flxst Mar 11, 2024
6d9849d
refactor: test causal self-attention
flxst Mar 12, 2024
65519c8
refactor: test causal self-attention (continued)
flxst Mar 12, 2024
c3242e3
test: causal self-attention type equality
flxst Mar 12, 2024
7d27b59
feat: replace current attention mechanism with `flash-attn`
lhahn-iis Mar 18, 2024
66addf4
Merge branch 'main' into GQA_2
fromm-m Mar 18, 2024
69fd5eb
fix: fixed qkv test
fromm-m Mar 19, 2024
0431479
refactor: refactored flash_attention
fromm-m Mar 19, 2024
bc27773
refactor: simplifiy reshaping and remove unused imports
mali-git Mar 20, 2024
fb88edb
refactor: refactor test
mali-git Mar 20, 2024
7235b73
fix: bugfix
fromm-m Mar 20, 2024
85f2224
Merge branch 'GQA_2' of https://github.com/Modalities/modalities into…
fromm-m Mar 20, 2024
3518ec1
fix: fixed test
fromm-m Mar 20, 2024
a4af491
fix: fix linting issues
fromm-m Mar 20, 2024
14906e9
fix: fixed config
fromm-m Mar 20, 2024
aac9a96
fix: improved the error message
fromm-m Mar 20, 2024
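
For orientation before the file-by-file diff: grouped-query attention (GQA) lets `n_head_q` query heads share a smaller set of `n_head_kv` key/value heads. Each key/value head is repeated `n_rep = n_head_q // n_head_kv` times so that ordinary scaled-dot-product attention can be applied. The snippet below is a minimal, self-contained sketch of that idea only; it is not the code in this PR, and the toy sizes are illustrative.

import torch

# Toy sizes (illustrative only): 8 query heads share 2 key/value heads.
B, T, n_embd = 2, 16, 128
n_head_q, n_head_kv = 8, 2
n_rep = n_head_q // n_head_kv        # each key/value head serves 4 query heads
hs = n_embd // n_head_q              # head size = 16

q = torch.randn(B, n_head_q, T, hs)   # (B, nh_q, T, hs)
k = torch.randn(B, n_head_kv, T, hs)  # (B, nh_kv, T, hs)
v = torch.randn(B, n_head_kv, T, hs)  # (B, nh_kv, T, hs)

# Repeat each key/value head n_rep times so the head counts match the queries.
k = k[:, :, None, :, :].expand(B, n_head_kv, n_rep, T, hs).reshape(B, n_head_q, T, hs)
v = v[:, :, None, :, :].expand(B, n_head_kv, n_rep, T, hs).reshape(B, n_head_q, T, hs)

# Standard causal attention on the expanded heads.
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(y.shape)  # torch.Size([2, 8, 16, 16])

The implementation in src/modalities/models/gpt2/gpt2_model.py below follows the same pattern via its repeat_kv helper.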
8 changes: 3 additions & 5 deletions config_files/config.yaml
@@ -142,15 +142,13 @@ model:
prediction_key: "logits"
block_size: ${data.sequence_len}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
n_head: 12
n_head_q: 12
n_head_kv: 12
ffn_hidden: 2048
n_embd: 768
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
attention_type: pytorch_flash_attention
scaling_factor: 3
attention_type: pytorch_flash_attention
activation: gelu
epsilon: 1e-5
weight_init:
7 changes: 3 additions & 4 deletions config_files/config_example_mem_map_dataset.yaml
@@ -138,14 +138,13 @@ model:
block_size: ${settings.training.sequence_length}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
n_head: 12
n_head_q: 12
n_head_kv: 12
ffn_hidden: 2048
n_embd: 768
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
attention_type: pytorch_flash_attention
scaling_factor: 3
attention_type: pytorch_flash_attention
activation: gelu
epsilon: 1e-5
weight_init:
7 changes: 3 additions & 4 deletions config_files/config_example_openGPTx_dataset.yaml
@@ -145,14 +145,13 @@ model:
block_size: ${data.sequence_len}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
n_head: 12
n_head_q: 12
n_head_kv: 12
ffn_hidden: 2048
n_embd: 768
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
attention_type: pytorch_flash_attention
scaling_factor: 3
attention_type: pytorch_flash_attention
activation: fused_swiglu
epsilon: 1e-5
weight_init:
7 changes: 3 additions & 4 deletions config_files/config_lorem_ipsum.yaml
@@ -191,14 +191,13 @@ model:
block_size: 256 # TODO reference this (same as sequence length)
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 2
n_head: 4
n_head_q: 8
n_head_kv: 2
ffn_hidden: 128
n_embd: 128
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
attention_type: default_attention # pytorch_flash_attention
scaling_factor: 3
attention_type: default_attention # pytorch_flash_attention
activation: gelu
epsilon: 1e-5
weight_init:
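
Of the configs shown so far, only config_files/config_lorem_ipsum.yaml actually groups heads (8 query heads over 2 key/value heads); the others keep n_head_q equal to n_head_kv, which reduces to plain multi-head attention. A quick, illustrative sanity check of the divisibility constraints that CausalSelfAttention asserts (not part of the diff):

# Values from config_files/config_lorem_ipsum.yaml
n_embd, n_head_q, n_head_kv = 128, 8, 2

assert n_embd % n_head_q == 0       # head size must be an integer
assert n_head_q % n_head_kv == 0    # query heads must group evenly onto key/value heads

n_rep = n_head_q // n_head_kv       # 4 query heads per key/value head
head_size = n_embd // n_head_q      # 16
kv_proj_out = n_embd // n_rep       # 32 = n_head_kv * head_size
print(n_rep, head_size, kv_proj_out)  # 4 16 32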
7 changes: 3 additions & 4 deletions examples/getting_started/example_config.yaml
@@ -114,14 +114,13 @@ model:
block_size: ${data.sequence_len}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
n_head: 12
n_head_q: 12
n_head_kv: 12
ffn_hidden: 2048
n_embd: 768
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
attention_type: pytorch_flash_attention
scaling_factor: 3
attention_type: pytorch_flash_attention
activation: gelu
epsilon: 1e-5
weight_init:
7 changes: 3 additions & 4 deletions examples/library_usage/config_lorem_ipsum.yaml
@@ -193,14 +193,13 @@ model:
block_size: 256 # TODO reference this (same as sequence length)
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 2
n_head: 4
n_head_q: 4
n_head_kv: 4
ffn_hidden: 128
n_embd: 128
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
attention_type: default_attention # pytorch_flash_attention
scaling_factor: 3
attention_type: default_attention # pytorch_flash_attention
activation: gelu
epsilon: 1e-5
weight_init:
117 changes: 86 additions & 31 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -24,11 +24,6 @@ class ActivationType(str, Enum):
FUSED_SWIGLU = "fused_swiglu"


class AttentionConfig(BaseModel):
attention_type: AttentionType
scaling_factor: Annotated[int, Field(strict=True, ge=1)]


class WeightInitailizationConfig(BaseModel):
mean: Annotated[float, Field(strict=True, ge=0.0)]
std: Annotated[float, Field(strict=True, ge=0.0)]
@@ -42,13 +37,14 @@ class GPT2LLMConfig(BaseModel):
int, Field(strict=True, ge=1)
] # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: Annotated[int, Field(strict=True, ge=1)]
n_head: Annotated[int, Field(strict=True, ge=1)]
n_head_q: Annotated[int, Field(strict=True, ge=1)]
n_head_kv: Annotated[int, Field(strict=True, ge=1)]
n_embd: Annotated[int, Field(strict=True, ge=1)]
ffn_hidden: Annotated[int, Field(strict=True, ge=1)]

dropout: Annotated[float, Field(strict=True, ge=0.0)]
bias: bool # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention: AttentionConfig
attention_type: AttentionType
activation: ActivationType
epsilon: Annotated[float, Field(strict=True, ge=0.0)]
weight_init: WeightInitailizationConfig
@@ -85,14 +81,41 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

class CausalSelfAttention(nn.Module):
def __init__(
self, n_head: int, n_embd: int, attention: AttentionConfig, bias: bool, dropout: float, block_size: int
self,
n_head_q: int,
n_head_kv: int,
n_embd: int,
attention_type: AttentionType,
bias: bool,
dropout: float,
block_size: int,
):
super().__init__()
assert n_embd % n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(
assert n_embd % n_head_q == 0, (
"The embedding dimension is split across `n_head_q` query heads "
"and therefore needs to be divisible by `n_head_q`."
)
assert n_head_q % n_head_kv == 0, (
"`n_head_q` needs to be divisible by `n_head_kv` so that the query heads "
'can be grouped evenly; see "Grouped Query Attention" for details.'
)

self.n_rep = n_head_q // n_head_kv

# query, key, value projections (separate)
self.q_attn = nn.Linear(
in_features=n_embd,
out_features=attention.scaling_factor * n_embd,
out_features=n_embd,
bias=bias,
)
self.k_attn = nn.Linear(
in_features=n_embd,
out_features=n_embd // self.n_rep,
bias=bias,
)
self.v_attn = nn.Linear(
in_features=n_embd,
out_features=n_embd // self.n_rep,
bias=bias,
)
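# Note on projection widths: q_attn keeps the full n_embd = n_head_q * head_size features,
# while k_attn and v_attn emit only n_embd // self.n_rep = n_head_kv * head_size features.
# The narrower key/value projections are where grouped-query attention saves parameters.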

@@ -106,10 +129,12 @@ def __init__(
# regularization
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
self.n_head = n_head
self.n_head_q = n_head_q
self.n_head_kv = n_head_kv

self.n_embd = n_embd
self.dropout = dropout
self.flash = attention.attention_type == AttentionType.PYTORCH_FLASH_ATTENTION
self.flash = attention_type == AttentionType.PYTORCH_FLASH_ATTENTION

if not self.flash:
# causal mask to ensure that attention is only applied to the left in the input sequence
@@ -119,15 +144,22 @@ def __init__(
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
B, T, _ = x.size() # batch size (B), sequence length (T), embedding dimensionality (self.n_embd)

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = self.q_attn(x) # (B, T, n_embd)
k = self.k_attn(x) # (B, T, n_embd / n_rep)
v = self.v_attn(x) # (B, T, n_embd / n_rep)

# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
q = q.view(B, T, self.n_head_q, self.n_embd // self.n_head_q).transpose(1, 2) # (B, nh_q, T, hs)
k = k.view(B, T, self.n_head_kv, self.n_embd // self.n_head_q).transpose(1, 2) # (B, nh_kv, T, hs)
v = v.view(B, T, self.n_head_kv, self.n_embd // self.n_head_q).transpose(1, 2) # (B, nh_kv, T, hs)

# repeat k/v heads if self.n_rep > 1
k = repeat_kv(k, self.n_rep) # (B, nh_q, T, hs)
v = repeat_kv(v, self.n_rep) # (B, nh_q, T, hs)

# causal self-attention; Self-attend: (B, nh_q, T, hs) x (B, nh_q, hs, T) -> (B, nh_q, T, T)
if self.flash:
# efficient attention using Flash Attention CUDA kernels
y = torch.nn.functional.scaled_dot_product_attention(
@@ -140,15 +172,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
else:
# manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # (B, nh_q, T, T)
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
y = att @ v # (B, nh_q, T, T) x (B, nh_q, T, hs) -> (B, nh_q, T, hs)
y = (
y.transpose(1, 2).contiguous().view(B, T, self.n_embd)
) # (B, T, n_embd), re-assemble all head outputs side by side

# output projection
y = self.resid_dropout(self.c_proj(y))
y = self.resid_dropout(self.c_proj(y)) # (B, T, n_embd)
return y


@@ -183,16 +217,23 @@ def __init__(
bias: bool,
epsilon: float,
activation: ActivationType,
n_head: int,
attention: AttentionConfig,
n_head_q: int,
n_head_kv: int,
attention_type: AttentionType,
dropout: float,
block_size: int,
ffn_hidden: int,
):
super().__init__()
self.ln_1 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon)
self.attn = CausalSelfAttention(
n_head=n_head, n_embd=n_embd, attention=attention, bias=bias, dropout=dropout, block_size=block_size
n_head_q=n_head_q,
n_head_kv=n_head_kv,
n_embd=n_embd,
attention_type=attention_type,
bias=bias,
dropout=dropout,
block_size=block_size,
)
self.ln_2 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon)

@@ -218,12 +259,13 @@ def __init__(
block_size: int,
vocab_size: int,
n_layer: int,
n_head: int,
n_head_q: int,
n_head_kv: int,
n_embd: int,
ffn_hidden: int,
dropout: float,
bias: bool,
attention: AttentionConfig,
attention_type: AttentionType,
activation: ActivationType,
epsilon: float,
weight_init: WeightInitailizationConfig,
@@ -248,8 +290,9 @@ def __init__(
bias=bias,
epsilon=epsilon,
activation=activation,
n_head=n_head,
attention=attention,
n_head_q=n_head_q,
n_head_kv=n_head_kv,
attention_type=attention_type,
dropout=dropout,
block_size=block_size,
ffn_hidden=ffn_hidden,
@@ -301,3 +344,15 @@ def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return self.forward_impl(inputs)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
Repeat each key/value head n_rep times along the head dimension: (B, nh_kv, T, hs) -> (B, nh_kv * n_rep, T, hs).
Source code adapted from
https://github.com/facebookresearch/llama/blob/9a001c7a0987afd7b8de94e538916eff8950a73a/llama/model.py#L164
with reordered dimensions and renamed variables: bs=B, n_kv_heads=nh_kv, slen=T, head_dim=hs.
"""
B, nh_kv, T, hs = x.shape
if n_rep == 1:
return x
return x[:, :, None, :, :].expand(B, nh_kv, n_rep, T, hs).reshape(B, nh_kv * n_rep, T, hs)
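
A usage sketch for repeat_kv (illustrative, assuming the helper defined above is in scope): with 2 key/value heads and n_rep = 4, every key/value head is duplicated four times along the head dimension so that the result lines up with 8 query heads.

import torch

kv = torch.randn(2, 2, 16, 16)   # (B=2, nh_kv=2, T=16, hs=16)
out = repeat_kv(kv, n_rep=4)     # -> (B=2, nh_q=8, T=16, hs=16)
assert out.shape == (2, 8, 16, 16)
# kv head i is repeated into output heads [i * n_rep, (i + 1) * n_rep)
assert torch.equal(out[:, 0], kv[:, 0]) and torch.equal(out[:, 3], kv[:, 0])
assert torch.equal(out[:, 4], kv[:, 1])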
7 changes: 3 additions & 4 deletions tests/checkpointing/gpt2_config.yaml
@@ -7,14 +7,13 @@ model:
block_size: 256 # TODO reference this (same as sequence length)
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 2
n_head: 4
n_head_q: 4
n_head_kv: 4
ffn_hidden: 128
n_embd: 128
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
attention_type: default_attention # pytorch_flash_attention
scaling_factor: 3
attention_type: default_attention # pytorch_flash_attention
activation: gelu
epsilon: 1e-5
weight_init:
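
To close, a hedged end-to-end sketch of exercising the reworked attention module with the lorem-ipsum sizes. The import path is inferred from the file location src/modalities/models/gpt2/gpt2_model.py and the AttentionType member from the diff; treat both as assumptions rather than documented API.

import torch
from modalities.models.gpt2.gpt2_model import AttentionType, CausalSelfAttention

attn = CausalSelfAttention(
    n_head_q=8,
    n_head_kv=2,
    n_embd=128,
    attention_type=AttentionType.PYTORCH_FLASH_ATTENTION,
    bias=True,
    dropout=0.0,
    block_size=256,
)
x = torch.randn(2, 256, 128)  # (B, T, n_embd)
y = attn(x)                   # causal grouped-query attention
print(y.shape)                # torch.Size([2, 256, 128])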