[FEAT][ImageToTextEmbeddings]
Kye committed Jan 9, 2024
1 parent c2c98b9 commit 0658bf4
Showing 6 changed files with 63 additions and 160 deletions.
14 changes: 4 additions & 10 deletions README.md
@@ -8,13 +8,7 @@ The open source implementation of Gemini, the model that will "eclipse ChatGPT",

[Join the Agora discord channel to help with the implementation!](https://discord.gg/CMDpRxCV8g) and [Here is the project board:](https://github.com/users/kyegomez/projects/11/views/1)

The input sequences for Gemini consist of texts, audio, images, and videos. These inputs are transformed into tokens, which are then processed by a transformer. Subsequently, conditional decoding takes place to generate image outputs.

Interestingly, the architecture of Gemini bears resemblance to Fuyu's architecture but is expanded to encompass multiple modalities. Instead of utilizing a visual transformer (vit) encoder, Gemini simply feeds image embeddings directly into the transformer.

For Gemini, the token inputs will likely be indicated by special modality tokens such as [IMG], <img>, [AUDIO], or <audio>. Codi, a component of Gemini, also employs conditional generation and makes use of the tokenized outputs.

To implement this model effectively, I intend to initially focus on the image embeddings to ensure their smooth integration. Subsequently, I will proceed with incorporating audio embeddings and then video embeddings.
The input sequences for Gemini consist of texts, audio, images, and videos. These inputs are transformed into tokens, which are then processed by a transformer. Subsequently, conditional decoding takes place to generate image outputs. Interestingly, the architecture of Gemini bears resemblance to Fuyu's architecture but is expanded to encompass multiple modalities. Instead of utilizing a vision transformer (ViT) encoder, Gemini simply feeds image embeddings directly into the transformer. For Gemini, the token inputs will likely be indicated by special modality tokens such as [IMG], <img>, [AUDIO], or <audio>. Codi, a component of Gemini, also employs conditional generation and makes use of the tokenized outputs. To implement this model effectively, I intend to initially focus on the image embeddings to ensure their smooth integration. Subsequently, I will proceed with incorporating audio embeddings and then video embeddings.
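
Below is a minimal, illustrative sketch of that Fuyu-style idea: raw image patches are linearly projected and concatenated with text token embeddings before a single decoder. All names and sizes here are assumptions for illustration, not the actual Gemini or gemini-torch API.

```python
# Hypothetical sketch only -- illustrative names and sizes, not the real API
import torch
from torch import nn
from einops import rearrange

dim, patch_size, vocab_size = 512, 16, 32000

text_embed = nn.Embedding(vocab_size, dim)
patch_embed = nn.Linear(patch_size * patch_size * 3, dim)  # no ViT encoder

text_ids = torch.randint(0, vocab_size, (1, 32))  # [B, text_len]
image = torch.randn(1, 3, 64, 64)                 # [B, C, H, W]

# Split the image into flattened patches
patches = rearrange(
    image, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size
)  # [B, num_patches, patch_size*patch_size*C]

# Project patches to the model dimension and prepend them to the text tokens
tokens = torch.cat([patch_embed(patches), text_embed(text_ids)], dim=1)
print(tokens.shape)  # [1, num_patches + text_len, dim] -> fed to the decoder
```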

# Install
`pip3 install gemini-torch`
@@ -140,11 +134,11 @@ print("Decoded audio:", decoded_audio)

```

### `ImgToEmbeddings`
### `ImageToTextEmbeddings`
- Takes in an image -> splits it into patches -> reshapes them to [B, SEQLEN, Dim] to align with the transformer
```python
import torch
from gemini_torch.utils import ImgToEmbeddings
from gemini_torch.utils import ImageToTextEmbeddings

# Example usage
num_patches = 16
@@ -154,7 +148,7 @@ img_channels = 3
seq_len = 50000
reduced_dim = 256 # Reduced dimension after dimensionality reduction

model = ImgToEmbeddings(
model = ImageToTextEmbeddings(
num_patches, patch_size, transformer_dim, img_channels, seq_len, reduced_dim
)

```
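
For reference, a small usage sketch that matches the new `ImageToTextEmbeddings` signature introduced in `gemini_torch/utils.py` below (`patch_size`, `dim`, `seq_len`). The values are illustrative and chosen so that the 64 patches of a 64x64 image line up with the sequence projection the new forward pass builds from the image height.

```python
import torch
from gemini_torch.utils import ImageToTextEmbeddings

# Illustrative values: an 8x8 grid of 8-pixel patches from a 64x64 image
model = ImageToTextEmbeddings(patch_size=8, dim=512, seq_len=128)

x = torch.randn(1, 3, 64, 64)  # [B, C, H, W]
y = model(x)
print(y.shape)  # expected: torch.Size([1, 128, 512]) -> [B, seq_len, dim]
```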
4 changes: 2 additions & 2 deletions gemini_torch/__init__.py
@@ -1,10 +1,10 @@
from gemini_torch.model import Gemini
from gemini_torch.utils import ImgToEmbeddings, AudioToEmbeddings
from gemini_torch.utils import ImageToTextEmbeddings, AudioToEmbeddings
from gemini_torch.tokenizer import MultimodalSentencePieceTokenizer

__all__ = [
"Gemini",
"ImgToEmbeddings",
"ImageToTextEmbeddings",
"AudioToEmbeddings",
"MultimodalSentencePieceTokenizer",
]
15 changes: 4 additions & 11 deletions gemini_torch/model.py
@@ -3,7 +3,7 @@
from zeta.structs import AutoregressiveWrapper

from gemini_torch.transformer import Decoder, Transformer
from gemini_torch.utils import ImgToEmbeddings, AudioToEmbeddings
from gemini_torch.utils import ImageToTextEmbeddings, AudioToEmbeddings


def exists(val):
@@ -85,15 +85,8 @@ def __init__(
# self.decoder = AutoregressiveWrapper(self.gemini)

# Takes in imgs -> patches them -> transforms them to the same dimension as the model
self.img_to_transformer = ImgToEmbeddings(
patches=patches,
patch_size=patch_size,
transformer_dim=dim,
img_channels=img_channels,
seq_len=num_tokens,
reduced_dim=dim,
*args,
**kwargs
self.img_to_text_embedding = ImageToTextEmbeddings(
patch_size=patches, dim=dim, seq_len=max_seq_len, *args, **kwargs
)

# Takes in audio -> transforms it to the same dimension as the model
@@ -103,7 +96,7 @@ def __init__(

except Exception as e:
print("Failed to initialize gemini: ", e)
raise
raise e

def forward(
self,
178 changes: 47 additions & 131 deletions gemini_torch/utils.py
@@ -1,133 +1,61 @@
import torch
from einops import rearrange, reduce
from torch import nn
from einops import rearrange


class ImgToEmbeddings(nn.Module):
"""ImgToEmbeddings
class ImageToTextEmbeddings(nn.Module):
"""
Converts images into text tokens using patch-based embedding.
Args:
patches (int): Number of patches to divide the image into
patch_size (int): Size of the patches
transformer_dim (int): Dimension of the transformer
img_channels (int): Number of channels in the image
seq_len (int): Length of the sequence
reduced_dim (int): Dimension of the reduced embedding
patch_size (int): The size of each patch in the image.
dim (int): The dimension of the embedding for each patch.
seq_len (int): The desired sequence length of the text tokens.
Returns:
torch.Tensor: The output of the model
Input shape:
(batch, channels, height, width)
Output shape:
(batch, seq_len, reduced_dim)
torch.Tensor: The text tokens representing the input images.
Example:
>>> import torch
>>> from geminix import ImgToEmbeddings
>>> model = ImgToEmbeddings(
... patches=16,
... patch_size=16,
... transformer_dim=512,
... img_channels=3,
... seq_len=128,
... reduced_dim=128
... )
>>> x = torch.randn(1, 3, 256, 256)
>>> y = model(x)
>>> y.shape
torch.Size([1, 128, 128])
"""

def __init__(
self,
patches: int,
patch_size: int,
transformer_dim: int,
img_channels: int,
seq_len: int,
reduced_dim: int,
*args,
**kwargs,
):
super(ImgToEmbeddings, self).__init__()
self.patches = patches
def __init__(self, patch_size, dim, seq_len):
super().__init__()
self.patch_size = patch_size
self.transformer_dim = transformer_dim
self.img_channels = img_channels
self.dim = dim
self.seq_len = seq_len
self.reduced_dim = reduced_dim

# Img is a square, calculate number of patches
self.num_patches_side = int(patches**0.5)

# Patch embedding layer
self.patch_embedding = nn.Linear(
patch_size * patch_size * img_channels, transformer_dim
)

# Dim reduction
self.dim_reduction = nn.Linear(transformer_dim, reduced_dim)

# # Batch Norm and relu
# self.norm = nn.BatchNorm1d(patches)
# self.activate = nn.ReLU()

# # Positional encoding
# self.positional_encoding = nn.Parameter(torch.zeros(1, patches, reduced_dim))

# Token mixing
self.token_mixer = nn.Linear(patches * reduced_dim, patches * reduced_dim)

# Linear layer to expand the seq to vocab
self.seq_expansion = nn.Linear(patches * reduced_dim, seq_len * reduced_dim)

def forward(self, x: torch.Tensor):
"""Forward pass
Args:
x (torch.Tensor): _description_
Returns:
_type_: _description_
"""
batch, channels, height, width, height = x.shape

# Check if img can be evenly divided into patches
assert (
height % self.num_patches_side == 0 and width % self.num_patches_side == 0
), "Image dimensions must be divisivle by the square root of patches"

# Reshape the img to patches
x = x.unfold(
2,
self.patch_size,
).unfold(3, self.patch_size, self.patch_size)
x = x.contiguous().view(batch, channels, self.num_patches, -1)
x = x.permute(0, 2, 1, 3).contiguous().view(batch, self.num_patches, -1)

# Apply patch embedding
x = self.patch_embedding(x)

# Dim reduction
x = self.dim_reduction(x)

# Batch norm
# x = self.norm(x)
# x = self.activate(x)

# Add positional encoding
x = x.view(batch, -1)
x = self.token_mixer(x)
x = x.view(batch, self.num_patches, -1)

# Expand the seq to match vocab
x = self.seq_expansion(x)
x = x.view(batch, self.seq_len, -1)

return x

self.projection = nn.Linear(patch_size * patch_size * 3, dim)
# self.seq_proj = nn.Linear(dim, seq_len)

def forward(self, images):
# Input images are assumed to be in the shape (batch_size, channels, height, width)
batch_size, _, height, width = images.shape

seq_proj = nn.Linear(height, self.seq_len)

# Ensure that the image dimensions are divisible by the patch size
assert height % self.patch_size == 0 and width % self.patch_size == 0, \
"Image dimensions must be divisible by the patch size"

# Rearrange the images into patches using einops
patches = rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)

# Project the patches into the embedding dimension
embeddings = self.projection(patches)

# Reshape the embeddings into the shape (batch_size, seq_len, dim)
seq_len = (height // self.patch_size) * (width // self.patch_size)
text_tokens = rearrange(embeddings, 'b (h w) e -> b h w e', h=seq_len, w=1)
text_tokens = reduce(text_tokens, "b h w e -> b h (w e)", "mean")

# Project the embeddings into the sequence length, in the 2nd dimension
text_tokens = rearrange(text_tokens, "b h d -> b d h", h=seq_len)
text_tokens = seq_proj(text_tokens)
text_tokens = rearrange(text_tokens, "b d h -> b h d")

return text_tokens

# x = torch.randn(1, 3, 64, 64)
# model = ImageToTextEmbeddings(patch_size=8, dim=512, seq_len=128)
# y = model(x)
# print(y.shape) # Should be [1, 128, 512] -> [B, seq_len, dim]


class AudioToEmbeddings(nn.Module):
"""AudioToEmbeddings
@@ -178,15 +106,3 @@ def forward(self, x):
x = rearrange(x, "b (s d) -> b s d", s=self.seqlen, d=self.dim)

return x


# # Example usage
# audio_seq_len = 32000 # Input audio sequence length
# seqlen = 512 # Sequence length to align with the language transformer
# dim = 512 # Embedding dimension

# model = AudioToEmbeddings(audio_seq_len, seqlen, dim)
# audio_input = torch.randn(1, audio_seq_len) # Example input tensor
# output = model(audio_input)

# print("Output shape:", output.shape) # Should be [1, 512, 512]
8 changes: 4 additions & 4 deletions tests/test_img_encoder.py
@@ -1,11 +1,11 @@
import torch
import unittest
from gemini_torch.utils import ImgToEmbeddings
from gemini_torch.utils import ImageToTextEmbeddings


class TestImgToEmbeddings(unittest.TestCase):
class TestImageToTextEmbeddings(unittest.TestCase):
def setUp(self):
self.model = ImgToEmbeddings(
self.model = ImageToTextEmbeddings(
patches=16,
patch_size=16,
transformer_dim=512,
@@ -15,7 +15,7 @@ def setUp(self):
)

def test_initialization(self):
self.assertIsInstance(self.model, ImgToEmbeddings)
self.assertIsInstance(self.model, ImageToTextEmbeddings)

def test_forward_with_img_256(self):
img = torch.randn(1, 3, 256, 256)
4 changes: 2 additions & 2 deletions tests/test_img_to_transformer.py
@@ -1,5 +1,5 @@
import torch
from gemini_torch.utils import ImgToEmbeddings
from gemini_torch.utils import ImageToTextEmbeddings

# Example usage
num_patches = 16
@@ -9,7 +9,7 @@
seq_len = 50000
reduced_dim = 256 # Reduced dimension after dimensionality reduction

model = ImgToEmbeddings(
model = ImageToTextEmbeddings(
num_patches, patch_size, transformer_dim, img_channels, seq_len, reduced_dim
)
