Skip to content

Commit

Permalink
[CODE QUALITY] Rename ImgToTransformer → ImgToEmbeddings and AudioToLangEmbedding → AudioToEmbeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 19, 2023
1 parent 32a473c commit de73222
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 36 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ print("Decoded audio:", decoded_audio)

```

### `ImgToTransformer`
### `ImgToEmbeddings`
- Takes an image, splits it into patches, and reshapes them to [B, SEQLEN, Dim] to align with the transformer
```python
import torch
from gemini_torch.utils import ImgToTransformer
from gemini_torch.utils import ImgToEmbeddings

# Example usage
num_patches = 16
Expand All @@ -149,7 +149,7 @@ img_channels = 3
seq_len = 50000
reduced_dim = 256 # Reduced dimension after dimensionality reduction

model = ImgToTransformer(
model = ImgToEmbeddings(
num_patches, patch_size, transformer_dim, img_channels, seq_len, reduced_dim
)

Expand All @@ -163,19 +163,19 @@ print(seq_space_output.shape) # Expected shape: [1, 50000, 256]

```

### `AudioToLangEmbedding`
### `AudioToEmbeddings`
- Transforms audio into the same shape as text tensors.

```python
import torch
from gemini_torch.utils import AudioToLangEmbedding
from gemini_torch.utils import AudioToEmbeddings

# Example usage
audio_seq_len = 32000 # Input audio sequence length
seqlen = 512 # Sequence length to align with the language transformer
dim = 512 # Embedding dimension

model = AudioToLangEmbedding(audio_seq_len, seqlen, dim)
model = AudioToEmbeddings(audio_seq_len, seqlen, dim)
audio_input = torch.randn(1, audio_seq_len) # Example input tensor
output = model(audio_input)

Expand Down
6 changes: 3 additions & 3 deletions gemini_torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from gemini_torch.model import Gemini
from gemini_torch.utils import ImgToTransformer, AudioToLangEmbedding
from gemini_torch.utils import ImgToEmbeddings, AudioToEmbeddings
from gemini_torch.tokenizer import MultimodalSentencePieceTokenizer

__all__ = [
"Gemini",
"ImgToTransformer",
"AudioToLangEmbedding",
"ImgToEmbeddings",
"AudioToEmbeddings",
"MultimodalSentencePieceTokenizer",
]
6 changes: 3 additions & 3 deletions gemini_torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from zeta.structs import AutoregressiveWrapper

from gemini_torch.transformer import Decoder, Transformer
from gemini_torch.utils import ImgToTransformer, AudioToLangEmbedding
from gemini_torch.utils import ImgToEmbeddings, AudioToEmbeddings


def exists(val):
Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(
self.decoder = AutoregressiveWrapper(self.gemini)

# Takes in imgs -> patches them -> transforms them to the same dimension as the model
self.img_to_transformer = ImgToTransformer(
self.img_to_transformer = ImgToEmbeddings(
patches=patches,
patch_size=patch_size,
transformer_dim=dim,
Expand All @@ -103,7 +103,7 @@ def __init__(
)

# Takes in audio -> transforms it to the same dimension as the model
self.audio_to_lang_embedding = AudioToLangEmbedding(
self.audio_to_lang_embedding = AudioToEmbeddings(
audio_seq_len=audio_seq_len, seqlen=num_tokens, dim=dim, *args, **kwargs
)

Expand Down
22 changes: 11 additions & 11 deletions gemini_torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from einops import rearrange


class ImgToTransformer(nn.Module):
"""ImgToTransformer
class ImgToEmbeddings(nn.Module):
"""ImgToEmbeddings
Args:
patches (int): Number of patches to divide the image into
Expand All @@ -25,8 +25,8 @@ class ImgToTransformer(nn.Module):
Example:
>>> import torch
>>> from geminix import ImgToTransformer
>>> model = ImgToTransformer(
>>> from geminix import ImgToEmbeddings
>>> model = ImgToEmbeddings(
... patches=16,
... patch_size=16,
... transformer_dim=512,
Expand All @@ -51,7 +51,7 @@ def __init__(
*args,
**kwargs,
):
super(ImgToTransformer, self).__init__()
super(ImgToEmbeddings, self).__init__()
self.patches = patches
self.patch_size = patch_size
self.transformer_dim = transformer_dim
Expand Down Expand Up @@ -129,8 +129,8 @@ def forward(self, x: torch.Tensor):
return x


class AudioToLangEmbedding(nn.Module):
"""AudioToLangEmbedding
class AudioToEmbeddings(nn.Module):
"""AudioToEmbeddings
Args:
audio_seq_len (int): Length of the audio sequence
Expand All @@ -139,8 +139,8 @@ class AudioToLangEmbedding(nn.Module):
Example:
>>> import torch
>>> from geminix import AudioToLangEmbedding
>>> model = AudioToLangEmbedding(
>>> from geminix import AudioToEmbeddings
>>> model = AudioToEmbeddings(
... audio_seq_len=32000,
... seqlen=512,
... dim=512
Expand All @@ -152,7 +152,7 @@ class AudioToLangEmbedding(nn.Module):
"""

def __init__(self, audio_seq_len, seqlen, dim):
super(AudioToLangEmbedding, self).__init__()
super(AudioToEmbeddings, self).__init__()
self.audio_seq_len = audio_seq_len
self.seqlen = seqlen
self.dim = dim
Expand Down Expand Up @@ -185,7 +185,7 @@ def forward(self, x):
# seqlen = 512 # Sequence length to align with the language transformer
# dim = 512 # Embedding dimension

# model = AudioToLangEmbedding(audio_seq_len, seqlen, dim)
# model = AudioToEmbeddings(audio_seq_len, seqlen, dim)
# audio_input = torch.randn(1, audio_seq_len) # Example input tensor
# output = model(audio_input)

Expand Down
14 changes: 7 additions & 7 deletions tests/test_audio_embedder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import torch
import pytest
from gemini_torch.utils import AudioToLangEmbedding
from gemini_torch.utils import AudioToEmbeddings


@pytest.fixture
def audio_embedding():
audio_seq_len = 32000
seqlen = 512
dim = 512
return AudioToLangEmbedding(audio_seq_len, seqlen, dim)
return AudioToEmbeddings(audio_seq_len, seqlen, dim)


def test_forward_pass(audio_embedding):
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_audio_seq_len_equal_seqlen(audio_embedding):
# Test when audio_seq_len is equal to seqlen
audio_seq_len = seqlen = 512
dim = 512
audio_embedding = AudioToLangEmbedding(audio_seq_len, seqlen, dim)
audio_embedding = AudioToEmbeddings(audio_seq_len, seqlen, dim)
input_audio = torch.randn(1, audio_seq_len)
output = audio_embedding(input_audio)
assert output.shape == (1, seqlen, dim)
Expand All @@ -108,7 +108,7 @@ def test_audio_seq_len_less_than_seqlen(audio_embedding):
audio_seq_len = 256
seqlen = 512
dim = 512
audio_embedding = AudioToLangEmbedding(audio_seq_len, seqlen, dim)
audio_embedding = AudioToEmbeddings(audio_seq_len, seqlen, dim)
input_audio = torch.randn(1, audio_seq_len)
output = audio_embedding(input_audio)
assert output.shape == (1, seqlen, dim)
Expand All @@ -119,7 +119,7 @@ def test_audio_seq_len_greater_than_seqlen(audio_embedding):
audio_seq_len = 1024
seqlen = 512
dim = 512
audio_embedding = AudioToLangEmbedding(audio_seq_len, seqlen, dim)
audio_embedding = AudioToEmbeddings(audio_seq_len, seqlen, dim)
input_audio = torch.randn(1, audio_seq_len)
output = audio_embedding(input_audio)
assert output.shape == (1, seqlen, dim)
Expand All @@ -130,7 +130,7 @@ def test_dim_less_than_seqlen(audio_embedding):
audio_seq_len = 32000
seqlen = 512
dim = 256
audio_embedding = AudioToLangEmbedding(audio_seq_len, seqlen, dim)
audio_embedding = AudioToEmbeddings(audio_seq_len, seqlen, dim)
input_audio = torch.randn(1, audio_seq_len)
output = audio_embedding(input_audio)
assert output.shape == (1, seqlen, dim)
Expand All @@ -141,7 +141,7 @@ def test_dim_greater_than_seqlen(audio_embedding):
audio_seq_len = 32000
seqlen = 512
dim = 1024
audio_embedding = AudioToLangEmbedding(audio_seq_len, seqlen, dim)
audio_embedding = AudioToEmbeddings(audio_seq_len, seqlen, dim)
input_audio = torch.randn(1, audio_seq_len)
output = audio_embedding(input_audio)
assert output.shape == (1, seqlen, dim)
8 changes: 4 additions & 4 deletions tests/test_img_encoder.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import torch
import unittest
from gemini_torch.utils import ImgToTransformer
from gemini_torch.utils import ImgToEmbeddings


class TestImgToTransformer(unittest.TestCase):
class TestImgToEmbeddings(unittest.TestCase):
def setUp(self):
self.model = ImgToTransformer(
self.model = ImgToEmbeddings(
patches=16,
patch_size=16,
transformer_dim=512,
Expand All @@ -15,7 +15,7 @@ def setUp(self):
)

def test_initialization(self):
self.assertIsInstance(self.model, ImgToTransformer)
self.assertIsInstance(self.model, ImgToEmbeddings)

def test_forward_with_img_256(self):
img = torch.randn(1, 3, 256, 256)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_img_to_transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from gemini_torch.utils import ImgToTransformer
from gemini_torch.utils import ImgToEmbeddings

# Example usage
num_patches = 16
Expand All @@ -9,7 +9,7 @@
seq_len = 50000
reduced_dim = 256 # Reduced dimension after dimensionality reduction

model = ImgToTransformer(
model = ImgToEmbeddings(
num_patches, patch_size, transformer_dim, img_channels, seq_len, reduced_dim
)

Expand Down

0 comments on commit de73222

Please sign in to comment.