feat: group-query-attention implementation #74

Merged · 24 commits · Mar 20, 2024
Changes from 23 commits
Commits (24)
f0ea511
feat: group-query-attention implementation
fromm-m Jan 30, 2024
ec8c807
chore: merge main into GQA
flxst Mar 11, 2024
5ae1c63
chore: align configs with new GQA keys
lhahn-iis Mar 11, 2024
8da8c1a
docs: add potential removal marker for "scaling_factor"
lhahn-iis Mar 11, 2024
415b0a6
test: add attention forward pass test for GQA
lhahn-iis Mar 11, 2024
6d4a6cf
fix: add verbose check for divisibility of K,V,Q matrix shapes
lhahn-iis Mar 11, 2024
e0e274d
refactor: remove AttentionConfig
flxst Mar 11, 2024
3725a54
debug: expanded KVs for GQA implementation
lhahn-iis Mar 11, 2024
1239381
fix: group query attention implementation
flxst Mar 11, 2024
6d9849d
refactor: test causal self-attention
flxst Mar 12, 2024
65519c8
refactor: test causal self-attention (continued)
flxst Mar 12, 2024
c3242e3
test: causal self-attention type equality
flxst Mar 12, 2024
7d27b59
feat: replace current attention mechanism with `flash-attn`
lhahn-iis Mar 18, 2024
66addf4
Merge branch 'main' into GQA_2
fromm-m Mar 18, 2024
69fd5eb
fix: fixed qkv test
fromm-m Mar 19, 2024
0431479
refactor: refactored flash_attention
fromm-m Mar 19, 2024
bc27773
refactor: simplifiy reshaping and remove unused imports
mali-git Mar 20, 2024
fb88edb
refactor: refactor test
mali-git Mar 20, 2024
7235b73
fix: bugfix
fromm-m Mar 20, 2024
85f2224
Merge branch 'GQA_2' of https://github.com/Modalities/modalities into…
fromm-m Mar 20, 2024
3518ec1
fix: fixed test
fromm-m Mar 20, 2024
a4af491
fix: fix linting issues
fromm-m Mar 20, 2024
14906e9
fix: fixed config
fromm-m Mar 20, 2024
aac9a96
fix: improved the error message
fromm-m Mar 20, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -121,6 +121,8 @@ venv.bak/
# Vscode
.vscode/
.history/
env_modalities/*
.devcontainer/*

# Created by https://www.toptal.com/developers/gitignore/api/vim
# Edit at https://www.toptal.com/developers/gitignore?templates=vim
8 changes: 2 additions & 6 deletions config_files/config.yaml
@@ -141,17 +141,13 @@ model:
prediction_key: "logits"
block_size: ${data.sequence_len}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
n_head: 12
n_head_q: 12
n_head_kv: 12
ffn_hidden: 2048
n_embd: 768
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
attention_type: pytorch_flash_attention
scaling_factor: 3
activation: gelu
epsilon: 1e-5
weight_init:
mean: 0.0
std: 0.02
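The central change across these configuration files is the split of the single `n_head` key into `n_head_q` (number of query heads) and `n_head_kv` (number of key/value heads), together with the removal of the old `attention` block and its `scaling_factor`. The snippet below is a minimal, self-contained sketch of what grouped-query attention does with those two numbers — it is not the repository's actual attention module, and the function name and shapes are illustrative only. The divisibility check mirrors the intent of commit 6d4a6cf.

```python
import torch
import torch.nn.functional as F


def grouped_query_attention(q, k, v, n_head_q: int, n_head_kv: int) -> torch.Tensor:
    """Illustrative GQA sketch.

    q: (B, T, n_head_q * head_dim); k, v: (B, T, n_head_kv * head_dim).
    """
    if n_head_q % n_head_kv != 0:
        raise ValueError(
            f"n_head_q ({n_head_q}) must be divisible by n_head_kv ({n_head_kv}) "
            "so each key/value head serves an equally sized group of query heads."
        )
    B, T, _ = q.shape
    head_dim = q.shape[-1] // n_head_q

    # reshape to (B, n_heads, T, head_dim)
    q = q.view(B, T, n_head_q, head_dim).transpose(1, 2)
    k = k.view(B, T, n_head_kv, head_dim).transpose(1, 2)
    v = v.view(B, T, n_head_kv, head_dim).transpose(1, 2)

    # expand K/V so each group of n_head_q // n_head_kv query heads shares one KV head
    repeats = n_head_q // n_head_kv
    k = k.repeat_interleave(repeats, dim=1)
    v = v.repeat_interleave(repeats, dim=1)

    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B, n_head_q, T, head_dim)
    return out.transpose(1, 2).reshape(B, T, n_head_q * head_dim)
```

With `n_head_q == n_head_kv` this reduces to standard multi-head attention, which is why the example configs that keep equal head counts behave as before.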
5 changes: 2 additions & 3 deletions config_files/config_example_mem_map_dataset.yaml
@@ -141,14 +141,13 @@ model:
poe_type: NOPE
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
n_head: 12
n_head_q: 12
n_head_kv: 12
ffn_hidden: 2048
n_embd: 768
dropout: 0.0
bias: true # True: bias in Linears, like GPT-2. False: a bit better and faster
attention_config:
attention_type: pytorch_flash_attention
scaling_factor: 3
qkv_transforms:
- type_hint: RotaryTransform
config:
7 changes: 2 additions & 5 deletions config_files/config_example_openGPTx_dataset.yaml
@@ -145,16 +145,13 @@ model:
block_size: ${data.sequence_len}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
n_head: 12
n_head_q: 12
n_head_kv: 12
ffn_hidden: 2048
n_embd: 768
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
attention_type: pytorch_flash_attention
scaling_factor: 3
activation: fused_swiglu
epsilon: 1e-5
weight_init:
mean: 0.0
std: 0.02
13 changes: 5 additions & 8 deletions config_files/config_lorem_ipsum.yaml
@@ -195,22 +195,20 @@ model:
prediction_key: ${loss_fn.config.prediction_key}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 2
n_head: 4
n_head_q: 8
n_head_kv: 8
ffn_hidden: 128
n_embd: 128
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention_config:
attention_type: pytorch_flash_attention
scaling_factor: 3
qkv_transforms:
- type_hint: RotaryTransform
config:
n_embd: ${model.config.n_embd}
n_head: ${model.config.n_head}
n_head: ${model.config.n_head_q} #it has to be head_q here
seq_length_dim: -2
activation_type: gelu
epsilon: 1e-5
weight_init:
mean: 0.0
std: 0.02
@@ -285,7 +283,6 @@ evaluation_subscriber:
config:
local_rank: ${settings.cuda_env.local_rank}
project: modalities
mode: ONLINE
mode: OFFLINE
experiment_id: ${settings.experiment_id}
directory: "."

directory: "."
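The `n_head_q` reference in the RotaryTransform config (with its inline comment) matters because rotary embeddings are applied per attention head: the per-head width is `n_embd // n_head_q`, and in GQA the key heads keep that same width even though there are fewer of them. Below is a rough sketch under that assumption — this is not the repository's `RotaryTransform`, just an illustration of why the query head count determines `head_dim`.

```python
import torch


def build_rotary_cache(head_dim: int, seq_len: int, base: float = 10000.0):
    # one rotation frequency per pair of channels inside a single head
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
    return angles.cos(), angles.sin()


def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (B, n_heads, T, head_dim) -- works for q (n_head_q heads) and k (n_head_kv heads)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


B, T = 2, 16                                   # illustrative batch size and sequence length
n_embd, n_head_q, n_head_kv = 128, 8, 8        # as in config_lorem_ipsum.yaml
head_dim = n_embd // n_head_q                  # must use n_head_q, not n_head_kv
cos, sin = build_rotary_cache(head_dim, T)
q = torch.randn(B, n_head_q, T, head_dim)
k = torch.randn(B, n_head_kv, T, head_dim)     # fewer heads are possible, same head_dim
q, k = apply_rotary(q, cos, sin), apply_rotary(k, cos, sin)
```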
7 changes: 3 additions & 4 deletions examples/getting_started/example_config.yaml
@@ -114,14 +114,13 @@ model:
block_size: ${data.sequence_len}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
n_head: 12
n_head_q: 12
n_head_kv: 12
ffn_hidden: 2048
n_embd: 768
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
attention_type: pytorch_flash_attention
scaling_factor: 3
attention_type: pytorch_flash_attention
activation: gelu
epsilon: 1e-5
weight_init:
7 changes: 3 additions & 4 deletions examples/library_usage/config_lorem_ipsum.yaml
@@ -193,14 +193,13 @@ model:
block_size: 256 # TODO reference this (same as sequence length)
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 2
n_head: 4
n_head_q: 4
n_head_kv: 4
ffn_hidden: 128
n_embd: 128
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
attention_type: default_attention # pytorch_flash_attention
scaling_factor: 3
attention_type: default_attention # pytorch_flash_attention
activation: gelu
epsilon: 1e-5
weight_init:
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -21,12 +21,15 @@ dependencies = [
"jq",
"xformers",
"class_resolver",
"wandb"
"wandb",
"flash-attn" # install this directly via `pip install flash-attn --no-build-isolation`

]

[project.optional-dependencies]
linting = ["pre-commit"]
tests = ["pytest", "pytest-cov"]
install_helper = ["ninja"]

[project.scripts]
modalities = "modalities.__main__:main"
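Commit 7d27b59 swaps the hand-rolled attention for the `flash-attn` kernel, which is why the dependency shows up here (installed separately via `pip install flash-attn --no-build-isolation`; the new `install_helper` extra with `ninja` presumably exists to speed up that source build). One convenient property for this PR: flash-attn's public `flash_attn_func` accepts a different head count for queries than for keys/values, so GQA needs no explicit KV expansion on the caller's side. A rough usage sketch, assuming flash-attn 2.x and a CUDA device (the kernel requires fp16/bf16 inputs):

```python
import torch
from flash_attn import flash_attn_func  # assumes flash-attn 2.x is installed

B, T, n_head_q, n_head_kv, head_dim = 2, 256, 12, 4, 64
q = torch.randn(B, T, n_head_q, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, T, n_head_kv, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, T, n_head_kv, head_dim, dtype=torch.bfloat16, device="cuda")

# GQA is handled inside the kernel when n_head_q is a multiple of n_head_kv;
# no repeat/expand of K and V is needed here.
out = flash_attn_func(q, k, v, causal=True)  # (B, T, n_head_q, head_dim)
```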
2 changes: 1 addition & 1 deletion src/modalities/config/look_up_enum.py
@@ -5,4 +5,4 @@ class LookupEnum(Enum):
@classmethod
def _missing_(cls, value: str) -> type:
"""constructs Enum by member name, if not constructable by value"""
return cls.__dict__[value]
return cls.__dict__[value]
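For context, `_missing_` is the hook `Enum` invokes when lookup by value fails; returning `cls.__dict__[value]` makes the enum additionally constructable from a member name, which is convenient when names come straight out of YAML configs. A small illustration with a hypothetical enum (not one taken from the codebase):

```python
from enum import Enum


class LookupEnum(Enum):
    @classmethod
    def _missing_(cls, value: str) -> type:
        """constructs Enum by member name, if not constructable by value"""
        return cls.__dict__[value]


class ActivationType(LookupEnum):  # hypothetical example enum
    gelu = "GELU"
    fused_swiglu = "FusedSwiGLU"


assert ActivationType("GELU") is ActivationType.gelu                   # regular lookup by value
assert ActivationType("fused_swiglu") is ActivationType.fused_swiglu   # falls back to lookup by name
```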
3 changes: 2 additions & 1 deletion src/modalities/config/utils.py
@@ -2,6 +2,7 @@

from pydantic import BaseModel


def convert_base_model_config_to_dict(config: BaseModel) -> Dict[Any, Any]:
""""Converts non-recursively a Pydantic BaseModel to a dictionary."""
""" "Converts non-recursively a Pydantic BaseModel to a dictionary."""
return {key: getattr(config, key) for key in config.model_dump().keys()}
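The docstring change above is cosmetic; the helper's behavior is unchanged: it maps each top-level field name to the attribute on the model itself, so nested `BaseModel` fields remain objects instead of being recursively dumped the way `model_dump()` would. A small illustration with hypothetical models (pydantic v2):

```python
from typing import Any, Dict

from pydantic import BaseModel


def convert_base_model_config_to_dict(config: BaseModel) -> Dict[Any, Any]:
    """Converts non-recursively a Pydantic BaseModel to a dictionary."""
    return {key: getattr(config, key) for key in config.model_dump().keys()}


class AttentionSettings(BaseModel):  # hypothetical example models, not from the codebase
    n_head_q: int
    n_head_kv: int


class ModelSettings(BaseModel):
    n_layer: int
    attention: AttentionSettings


cfg = ModelSettings(n_layer=2, attention=AttentionSettings(n_head_q=8, n_head_kv=4))

print(convert_base_model_config_to_dict(cfg))
# {'n_layer': 2, 'attention': AttentionSettings(n_head_q=8, n_head_kv=4)}  <- nested model kept as an object
print(cfg.model_dump())
# {'n_layer': 2, 'attention': {'n_head_q': 8, 'n_head_kv': 4}}             <- fully recursive dump
```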
@@ -2,10 +2,10 @@
from typing import Dict, Optional

import rich
import wandb
from rich.console import Group
from rich.panel import Panel

import wandb
from modalities.batch import EvaluationResultBatch
from modalities.config.config import WandbMode
from modalities.logging_broker.messages import Message