Skip to content

Commit

Permalink
[EXAMPLE]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 21, 2023
1 parent ee70d8a commit babc10d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 52 deletions.
38 changes: 22 additions & 16 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,36 @@
import torch
from gemini_torch import Gemini
from gemini_torch.model import Gemini

# Initialize the model
# Initialize model
model = Gemini(
num_tokens=12608,
max_seq_len=2048,
dim=640,
depth=8,
dim_head=32,
heads=6,
num_tokens=50432,
max_seq_len=8192,
dim=2560,
depth=32,
dim_head=128,
heads=24,
use_abs_pos_emb=False,
alibi_pos_bias=True,
alibi_num_heads=3,
alibi_num_heads=12,
rotary_xpos=True,
attn_flash=True,
attn_kv_heads=1,
attn_kv_heads=2,
qk_norm=True,
attn_qk_norm=True,
attn_qk_norm_dim_scale=True,
)

# Initialize the text random tokens
x = torch.randint(0, 12608, (1, 2048))
# Text shape: [batch, seq_len, dim]
text = torch.randint(0, 50432, (1, 8192))

# Apply model to x
y = model(x)
# Img shape: [batch, channels, height, width]
img = torch.randn(1, 3, 256, 256)

# Print logits
print(y)
# Audio shape: [batch, audio_seq_len, dim]
audio = torch.randn(1, 128)

# Apply model to text and img
y = model(text, img, audio)

# Output shape: [batch, seq_len, dim]
print(y.shape)
36 changes: 0 additions & 36 deletions multi_modal_example.py

This file was deleted.

0 comments on commit babc10d

Please sign in to comment.