Skip to content

Commit

Permalink
[CQ]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 13, 2023
1 parent 4e204b1 commit ad2e807
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
4 changes: 2 additions & 2 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import torch
import torch
from gemini_torch import Gemini

# Initialize the model
Expand Down Expand Up @@ -27,4 +27,4 @@
y = model(x)

# Print logits
print(y)
print(y)
2 changes: 0 additions & 2 deletions gemini_torch/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,3 @@ def decode(self, tokens: List[int]) -> str:
for start_id, end_id in self.modality_tokens.values():
tokens = [t for t in tokens if t not in (start_id, end_id)]
return self.sp_model.decode(tokens)


23 changes: 20 additions & 3 deletions tests/test_audio_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,70 +2,81 @@
import pytest
from gemini_torch.utils import AudioToLangEmbedding


@pytest.fixture
def audio_embedding():
    """Build the AudioToLangEmbedding instance shared by the tests in this module.

    The module is constructed positionally with (32000, 512, 512); the tests
    read ``audio_seq_len``, ``seqlen`` and ``dim`` back from the instance.
    """
    # Positional order: audio_seq_len, seqlen, dim.
    config = (32000, 512, 512)
    return AudioToLangEmbedding(*config)


def test_forward_pass(audio_embedding):
    """A random two-row batch produces a (batch, seqlen, dim) output."""
    rows = 2
    waveform = torch.randn(rows, audio_embedding.audio_seq_len)
    result = audio_embedding(waveform)
    expected_shape = (rows, audio_embedding.seqlen, audio_embedding.dim)
    assert result.shape == expected_shape


def test_device_placement(audio_embedding):
    """Input, output and the projection weight all report the same device."""
    sample = torch.randn(1, audio_embedding.audio_seq_len)
    result = audio_embedding(sample)
    source_device = sample.device
    assert source_device == result.device
    assert source_device == audio_embedding.projection.weight.device


def test_output_shape(audio_embedding):
    """A single random input row maps to an output of shape (1, seqlen, dim)."""
    single = torch.randn(1, audio_embedding.audio_seq_len)
    embedded = audio_embedding(single)
    expected_shape = (1, audio_embedding.seqlen, audio_embedding.dim)
    assert embedded.shape == expected_shape


def test_batch_processing(audio_embedding):
    """The leading batch dimension (here 4) is carried through to the output."""
    n_rows = 4
    batch = torch.randn(n_rows, audio_embedding.audio_seq_len)
    embedded = audio_embedding(batch)
    expected_shape = (n_rows, audio_embedding.seqlen, audio_embedding.dim)
    assert embedded.shape == expected_shape


def test_zero_input(audio_embedding):
    """An all-zero input is expected to produce an all-zero output."""
    silence = torch.zeros(1, audio_embedding.audio_seq_len)
    embedded = audio_embedding(silence)
    # NOTE(review): exact zeros only come out if the module adds no offset
    # (e.g. a bias term) to its projection -- confirm against the module.
    is_zero = embedded == 0
    assert torch.all(is_zero)


def test_negative_input(audio_embedding):
    """Feed standard-normal noise shifted down by 2.0 and expect negative outputs."""
    # NOTE(review): randn() - 2.0 is mostly, but not strictly, negative, and
    # nothing visible here guarantees the module preserves the input's sign --
    # confirm this expectation against the module's implementation.
    shifted = torch.randn(1, audio_embedding.audio_seq_len) - 2.0
    embedded = audio_embedding(shifted)
    is_negative = embedded < 0
    assert torch.all(is_negative)


def test_large_input(audio_embedding):
    """Feed standard-normal noise scaled by 100.0 and expect positive outputs."""
    # NOTE(review): randn() * 100.0 contains values of both signs, and nothing
    # visible here guarantees strictly positive outputs -- confirm this
    # expectation against the module's implementation.
    scaled = torch.randn(1, audio_embedding.audio_seq_len) * 100.0
    embedded = audio_embedding(scaled)
    is_positive = embedded > 0
    assert torch.all(is_positive)


def test_input_shape_mismatch(audio_embedding):
    """An input one sample longer than ``audio_seq_len`` must be rejected.

    The previous expectation referenced
    ``torch.nn.modules.module.ModuleAttributeError``, a class that current
    PyTorch releases no longer provide, so the test itself failed with
    ``AttributeError`` before the forward pass ran. Shape mismatches inside
    torch operations surface as ``RuntimeError``.

    NOTE(review): if AudioToLangEmbedding validates its input shape itself
    and raises a different exception type, adjust the expectation -- confirm.
    """
    # Build the bad input outside the ``raises`` block so that only the
    # forward call can satisfy (or fail) the expectation.
    bad_input = torch.randn(1, audio_embedding.audio_seq_len + 1)
    with pytest.raises(RuntimeError):
        audio_embedding(bad_input)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA device required")
def test_output_device(audio_embedding):
    """After moving module and input to CUDA, the output lives on a CUDA device.

    Skipped on machines without CUDA instead of erroring in ``.to("cuda")``.
    """
    input_audio = torch.randn(1, audio_embedding.audio_seq_len).to("cuda")
    audio_embedding.to("cuda")
    output = audio_embedding(input_audio)
    # Compare the device *type*: a tensor placed via "cuda" reports an indexed
    # device such as cuda:0, which does not compare equal to the index-less
    # torch.device("cuda").
    assert output.device.type == "cuda"


def test_large_batch_size(audio_embedding):
# Test with a large batch size
Expand All @@ -74,12 +85,14 @@ def test_large_batch_size(audio_embedding):
output = audio_embedding(input_audio)
assert output.shape == (batch_size, audio_embedding.seqlen, audio_embedding.dim)


def test_small_batch_size(audio_embedding):
    """The smallest batch (one row) yields an output of shape (1, seqlen, dim)."""
    lone_row = torch.randn(1, audio_embedding.audio_seq_len)
    embedded = audio_embedding(lone_row)
    expected_shape = (1, audio_embedding.seqlen, audio_embedding.dim)
    assert embedded.shape == expected_shape


def test_audio_seq_len_equal_seqlen(audio_embedding):
# Test when audio_seq_len is equal to seqlen
audio_seq_len = seqlen = 512
Expand All @@ -89,6 +102,7 @@ def test_audio_seq_len_equal_seqlen(audio_embedding):
output = audio_embedding(input_audio)
assert output.shape == (1, seqlen, dim)


def test_audio_seq_len_less_than_seqlen(audio_embedding):
# Test when audio_seq_len is less than seqlen
audio_seq_len = 256
Expand All @@ -99,6 +113,7 @@ def test_audio_seq_len_less_than_seqlen(audio_embedding):
output = audio_embedding(input_audio)
assert output.shape == (1, seqlen, dim)


def test_audio_seq_len_greater_than_seqlen(audio_embedding):
# Test when audio_seq_len is greater than seqlen
audio_seq_len = 1024
Expand All @@ -109,6 +124,7 @@ def test_audio_seq_len_greater_than_seqlen(audio_embedding):
output = audio_embedding(input_audio)
assert output.shape == (1, seqlen, dim)


def test_dim_less_than_seqlen(audio_embedding):
# Test when dim is less than seqlen
audio_seq_len = 32000
Expand All @@ -119,6 +135,7 @@ def test_dim_less_than_seqlen(audio_embedding):
output = audio_embedding(input_audio)
assert output.shape == (1, seqlen, dim)


def test_dim_greater_than_seqlen(audio_embedding):
# Test when dim is greater than seqlen
audio_seq_len = 32000
Expand Down

0 comments on commit ad2e807

Please sign in to comment.