
Commit

no grcpio
Kye committed Jan 10, 2024
1 parent 9e9dbe1 commit 8ec2a88
Showing 4 changed files with 161 additions and 35 deletions.
109 changes: 109 additions & 0 deletions 2d.py
@@ -0,0 +1,109 @@
import torch
from einops import rearrange, reduce
from torch import nn


class ImageToTextEmbeddings(nn.Module):
"""
Converts images into text tokens using patch-based embedding.
Args:
patch_size (int): The size of each patch in the image.
dim (int): The dimension of the embedding for each patch.
seq_len (int): The desired sequence length of the text tokens.
Returns:
torch.Tensor: The text tokens representing the input images.
"""
def __init__(self, patch_size: int, dim: int, seq_len: int):
super().__init__()
self.patch_size = patch_size
self.dim = dim
self.seq_len = seq_len
self.projection = nn.Linear(patch_size * patch_size * 3, dim)

def forward(self, images):
# Input images are assumed to be in the shape (batch_size, channels, height, width)
batch_size, _, height, width = images.shape

# Ensure that the image dimensions are divisible by the patch size
assert height % self.patch_size == 0 and width % self.patch_size == 0, \
"Image dimensions must be divisible by the patch size"

# Rearrange the images into patches using einops
patches = rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)

# Project the patches into the embedding dimension
embeddings = self.projection(patches)

# Pool each patch embedding down to a single value per patch
seq_len = (height // self.patch_size) * (width // self.patch_size)
text_tokens = rearrange(embeddings, 'b (h w) e -> b h w e', h=seq_len, w=1)
text_tokens = reduce(text_tokens, 'b h w e -> b h', 'mean')

# Project the per-patch values to the requested sequence length
# Note: this instantiates a new, untrained nn.Linear on every forward call
b, h = text_tokens.shape
proj = nn.Linear(h, self.seq_len)
text_tokens = proj(text_tokens)

return text_tokens

# x = torch.randn(1, 3, 64, 64)
# model = ImageToTextEmbeddings(patch_size=8, dim=512, seq_len=128)
# y = model(x)
# print(y.shape)  # Should be [1, 128] (batch, seq_len)


class AudioToEmbeddings(nn.Module):
"""AudioToEmbeddings
Args:
audio_seq_len (int): Length of the raw audio sequence
seqlen (int): Target sequence length after projection
Example:
>>> import torch
>>> from geminix import AudioToEmbeddings
>>> model = AudioToEmbeddings(
... audio_seq_len=32000,
... seqlen=512,
... )
>>> x = torch.randn(1, 32000)
>>> y = model(x)
>>> y.shape
torch.Size([1, 512])
"""

def __init__(self, audio_seq_len: int, seqlen: int):
super(AudioToEmbeddings, self).__init__()
self.audio_seq_len = audio_seq_len
self.seqlen = seqlen
# Linear layer that projects the raw audio sequence down to the target sequence length
self.projection = nn.Linear(audio_seq_len, seqlen)

def forward(self, x):
"""Forward pass
Args:
x (torch.Tensor): Audio tensor of shape [batch, audio_seq_len]
Returns:
torch.Tensor: Projected tensor of shape [batch, seqlen]
"""
# x shape: [batch, audio_seq_len] - 2D input
batch, audio_seq_len = x.shape

# Project the audio tensor to the target sequence length
x = self.projection(x) # x shape: [batch, seqlen]

return x

# x = torch.randn(1, 32000)
# model = AudioToEmbeddings(audio_seq_len=32000, seqlen=512)
# y = model(x)
# print(y.shape)  # Should be [1, 512]
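For reference, a standalone shape trace mirroring the two modules added in 2d.py above (not part of the commit; it simply replays the same operations with patch_size=8, dim=512, seq_len=128 for images and audio_seq_len=32000, seqlen=512 for audio):

import torch
from einops import rearrange, reduce
from torch import nn

# ImageToTextEmbeddings path: 64x64 RGB image -> 64 patches -> [batch, seq_len]
images = torch.randn(1, 3, 64, 64)
patches = rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=8, p2=8)
print(patches.shape)      # torch.Size([1, 64, 192])
embeddings = nn.Linear(8 * 8 * 3, 512)(patches)
print(embeddings.shape)   # torch.Size([1, 64, 512])
pooled = reduce(rearrange(embeddings, 'b (h w) e -> b h w e', h=64, w=1), 'b h w e -> b h', 'mean')
print(pooled.shape)       # torch.Size([1, 64]), one value per patch
tokens = nn.Linear(64, 128)(pooled)
print(tokens.shape)       # torch.Size([1, 128]), i.e. [batch, seq_len]

# AudioToEmbeddings path: raw audio -> [batch, seqlen]
audio = torch.randn(1, 32000)
audio_tokens = nn.Linear(32000, 512)(audio)
print(audio_tokens.shape) # torch.Size([1, 512])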
9 changes: 5 additions & 4 deletions example.py
@@ -18,7 +18,7 @@
patches=8, # Reduced from 16
patch_size=8, # Reduced from 16
img_channels=3, # Reduced from 3
# audio_seq_len=32, # Reduced from 64
audio_seq_len=32, # Reduced from 64
)

# Text shape: [batch, seq_len, dim]
@@ -28,10 +28,11 @@
img = torch.randn(1, 3, 64, 64) # Reduced height and width from 128

# Audio shape: [batch, audio_seq_len]
# audio = torch.randn(1, 32) # Reduced audio_seq_len from 64
audio = torch.randn(1, 32) # Reduced audio_seq_len from 64

# Apply model to text, img, and audio
y = model(text, img)
y = model(text=text, img=img, audio=audio)

# Output shape: [batch, seq_len, dim]
print(y)
print(y)
print(y.shape)
49 changes: 29 additions & 20 deletions gemini_torch/model.py
@@ -56,7 +56,7 @@ def __init__(
img_channels: int = 3,
audio_seq_len: int = 128,
*args,
**kwargs
**kwargs,
):
super().__init__()

@@ -77,21 +77,24 @@ def __init__(
attn_qk_norm=attn_qk_norm,
attn_qk_norm_dim_scale=attn_qk_norm_dim_scale,
*args,
**kwargs
**kwargs,
),
)

# Autoregressive wrapper for the model
# self.decoder = AutoregressiveWrapper(self.gemini)
self.decoder = AutoregressiveWrapper(self.gemini)

# Takes in imgs -> patches them -> transforms them to the same dimension as the model
self.img_to_text_embedding = ImageToTextEmbeddings(
patch_size=patches, dim=dim, seq_len=num_tokens, *args, **kwargs
patch_size=patches, dim=dim, seq_len=max_seq_len, *args, **kwargs
)

# Takes in audio -> transforms it to the same dimension as the model
self.audio_to_lang_embedding = AudioToEmbeddings(
audio_seq_len=audio_seq_len, seqlen=max_seq_len, dim=dim, *args, **kwargs
audio_seq_len=audio_seq_len,
seqlen=max_seq_len,
*args,
**kwargs,
)

except Exception as e:
@@ -104,7 +107,7 @@ def forward(
img: torch.Tensor = None,
audio: torch.Tensor = None,
*args,
**kwargs
**kwargs,
):
"""
Forward pass of the model.
@@ -128,24 +131,30 @@
try:
if exists(img) and exists(audio):
# Process audio and image inputs
audio_emb = self.audio_to_lang_embedding(audio)
img_emb = self.img_to_transformer(img)
audio = self.audio_to_lang_embedding(audio)
img = self.img_to_text_embedding(img)

# Concatenate text, image, and audio embeddings
x = torch.cat((text, img_emb, audio_emb))

if exists(img):
# x = torch.cat((text, img_emb, audio_emb))
fused = torch.cat((text, img, audio))
return self.decoder(text, prepend_embeds=fused, *args, **kwargs)
elif exists(img):
# Process image input
x = self.img_to_text_embedding(img)
print(f"Image shape: {x.shape}")
x = torch.cat((text, x))
print(f"Concat shape: {x.shape}")
return x
img = self.img_to_text_embedding(img)
# print(f"Image shape: {x.shape}")
# x = torch.cat((text, x))
# print(f"Concat shape: {x.shape}")
# return x
return self.decoder(text, prepend_embeds=img, *args, **kwargs)
elif exists(audio):
# Process audio input
audio = self.audio_to_lang_embedding(audio)
# x = torch.cat((text, x), dim=1)
# return audio
# Call the forward method of the decoder once
return self.decoder(text, prepend_embeds=audio, *args, **kwargs)
else:
x = text

# Call the forward method of the decoder once
return self.decoder(x, padded_x=x)
return self.decoder(text, *args, **kwargs)
except Exception as e:
print("Failed in forward method: ", e)
raise
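The substantive change in forward above is that image and audio embeddings now reach the decoder through prepend_embeds instead of a bare torch.cat followed by an early return. A minimal, self-contained sketch of that fusion idea, using toy shapes and a plain embedding table rather than the actual gemini_torch decoder API, so the tensor names and sizes here are illustrative assumptions only:

import torch
from torch import nn

dim, vocab_size = 512, 1000
token_embedding = nn.Embedding(vocab_size, dim)

text_ids = torch.randint(0, vocab_size, (1, 128))  # [batch, text_seq_len] token ids
img_embeds = torch.randn(1, 64, dim)                # hypothetical image embeddings
audio_embeds = torch.randn(1, 32, dim)              # hypothetical audio embeddings

# "Prepend" fusion: modality embeddings are placed ahead of the embedded text
# tokens so the decoder attends over one combined sequence.
fused = torch.cat((img_embeds, audio_embeds, token_embedding(text_ids)), dim=1)
print(fused.shape)  # torch.Size([1, 224, 512])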
29 changes: 18 additions & 11 deletions gemini_torch/utils.py
@@ -1,3 +1,4 @@
import torch
from einops import rearrange, reduce
from torch import nn

@@ -15,20 +16,17 @@ class ImageToTextEmbeddings(nn.Module):
torch.Tensor: The text tokens representing the input images.
"""
def __init__(self, patch_size, dim, seq_len):
def __init__(self, patch_size: int, dim: int, seq_len: int):
super().__init__()
self.patch_size = patch_size
self.dim = dim
self.seq_len = seq_len
self.projection = nn.Linear(patch_size * patch_size * 3, dim)
# self.seq_proj = nn.Linear(dim, seq_len)


def forward(self, images):
# Input images are assumed to be in the shape (batch_size, channels, height, width)
batch_size, _, height, width = images.shape

seq_proj = nn.Linear(height, self.seq_len)

# Ensure that the image dimensions are divisible by the patch size
assert height % self.patch_size == 0 and width % self.patch_size == 0, \
"Image dimensions must be divisible by the patch size"
@@ -42,12 +40,12 @@ def forward(self, images):
# Reshape the embeddings into the shape (batch_size, seq_len, dim)
seq_len = (height // self.patch_size) * (width // self.patch_size)
text_tokens = rearrange(embeddings, 'b (h w) e -> b h w e', h=seq_len, w=1)
text_tokens = reduce(text_tokens, "b h w e -> b h (w e)", "mean")

# Project the embeddings into the sequence length, in the 2nd dimension
text_tokens = rearrange(text_tokens, "b h d -> b d h", h=seq_len)
text_tokens = reduce(text_tokens, "b h w e -> b (w e h)", "mean")
seq_proj = nn.Linear(seq_len, self.seq_len)
text_tokens = seq_proj(text_tokens)
text_tokens = rearrange(text_tokens, "b d h -> b h d")




return text_tokens

@@ -85,7 +83,7 @@ def __init__(self, audio_seq_len: int, seqlen: int, dim: int):
self.seqlen = seqlen
self.dim = dim
# Initialize a linear layer to project the 2D audio input to the desired 3D shape
self.projection = nn.Linear(audio_seq_len, seqlen * dim)
self.projection = nn.Linear(audio_seq_len, dim)

def forward(self, x):
"""Forward pass
@@ -106,3 +104,12 @@ def forward(self, x):
x = rearrange(x, "b (s d) -> b s d", s=self.seqlen, d=self.dim)

return x

x = torch.randn(1, 32000)
model = AudioToEmbeddings(
audio_seq_len=32000,
seqlen=512,
dim=512
)
y = model(x)
print(y.shape) # Should be [1, 512, 512]
