Skip to content

Commit

Permalink
[BUFG][just images for now]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Jan 9, 2024
1 parent 9f8da27 commit dc002a3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
4 changes: 2 additions & 2 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
img = torch.randn(1, 3, 64, 64) # Reduced height and width from 128

# Audio shape: [batch, audio_seq_len, dim]
audio = torch.randn(1, 32) # Reduced audio_seq_len from 64
# audio = torch.randn(1, 32) # Reduced audio_seq_len from 64

# Apply model to text and img
y = model(text, img, audio)
y = model(text, img)

# Output shape: [batch, seq_len, dim]
print(y)
6 changes: 5 additions & 1 deletion gemini_torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,11 @@ def forward(
img_emb = self.img_to_transformer(img)

# Concatenate text, image, and audio embeddings
x = torch.cat((text, img_emb, audio_emb), dim=1)
x = torch.cat((text, img_emb, audio_emb))

if exists(img):
# Process image input
x = self.img_to_text_embedding(img)
else:
x = text

Expand Down

0 comments on commit dc002a3

Please sign in to comment.