From dbffc205e1d388d6f922ff9d5e483d29bc2176b6 Mon Sep 17 00:00:00 2001
From: Kye
Date: Sat, 23 Dec 2023 21:46:07 -0500
Subject: [PATCH] [NEW V] [Removal of Alibi+XPOS]

---
 example.py                          | 20 ++++++++++----------
 gemini_torch/audio_encoder_usm.py   |  0
 gemini_torch/model.py               |  6 ------
 {data => tokenizer}/tokenizer.model | Bin
 4 files changed, 10 insertions(+), 16 deletions(-)
 delete mode 100644 gemini_torch/audio_encoder_usm.py
 rename {data => tokenizer}/tokenizer.model (100%)

diff --git a/example.py b/example.py
index 9f4751f..19d4255 100644
--- a/example.py
+++ b/example.py
@@ -1,17 +1,17 @@
 import torch
 from gemini_torch.model import Gemini
 
-# Initialize model
+# Initialize model with smaller dimensions
 model = Gemini(
     num_tokens=50432,
-    max_seq_len=8192,
-    dim=2560,
-    depth=32,
-    dim_head=128,
-    heads=24,
+    max_seq_len=4096,  # Reduced from 8192
+    dim=1280,  # Reduced from 2560
+    depth=16,  # Reduced from 32
+    dim_head=64,  # Reduced from 128
+    heads=12,  # Reduced from 24
     use_abs_pos_emb=False,
     alibi_pos_bias=True,
-    alibi_num_heads=12,
+    alibi_num_heads=6,  # Reduced from 12
     rotary_xpos=True,
     attn_flash=True,
     attn_kv_heads=2,
@@ -21,13 +21,13 @@
 )
 
 # Text shape: [batch, seq_len, dim]
-text = torch.randint(0, 50432, (1, 8192))
+text = torch.randint(0, 50432, (1, 4096))  # Reduced seq_len from 8192
 
 # Img shape: [batch, channels, height, width]
-img = torch.randn(1, 3, 256, 256)
+img = torch.randn(1, 3, 128, 128)  # Reduced height and width from 256
 
 # Audio shape: [batch, audio_seq_len, dim]
-audio = torch.randn(1, 128)
+audio = torch.randn(1, 64)  # Reduced audio_seq_len from 128
 
 # Apply model to text and img
 y = model(text, img, audio)
diff --git a/gemini_torch/audio_encoder_usm.py b/gemini_torch/audio_encoder_usm.py
deleted file mode 100644
index e69de29..0000000
diff --git a/gemini_torch/model.py b/gemini_torch/model.py
index 9a869fa..24fa54b 100644
--- a/gemini_torch/model.py
+++ b/gemini_torch/model.py
@@ -46,9 +46,6 @@ def __init__(
         dim_head=128,
         heads=24,
         use_abs_pos_emb=False,
-        alibi_pos_bias=True,
-        alibi_num_heads=12,
-        rotary_xpos=True,
         attn_flash=True,
         attn_kv_heads=2,
         qk_norm=True,
@@ -74,9 +71,6 @@ def __init__(
             depth=depth,
             dim_head=dim_head,
             heads=heads,
-            alibi_pos_bias=alibi_pos_bias,
-            alibi_num_heads=alibi_num_heads,
-            rotary_xpos=rotary_xpos,
             attn_flash=attn_flash,
             attn_kv_heads=attn_kv_heads,
             qk_norm=qk_norm,
diff --git a/data/tokenizer.model b/tokenizer/tokenizer.model
similarity index 100%
rename from data/tokenizer.model
rename to tokenizer/tokenizer.model