Skip to content

Commit

Permalink
Merge branch 'master' into retinanet-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
sineeli committed Feb 3, 2025
2 parents b6b1191 + c664457 commit c1eaf64
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions keras_hub/src/models/pali_gemma/pali_gemma_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,8 @@ def __init__(
self.intermediate_dim = intermediate_dim

def compute_attention(self, x, mask=None):
mask = None
if mask is not None:
mask = ops.cast(mask, dtype=x.dtype) if mask is not None else None
mask = ops.cast(mask, dtype=x.dtype)
return self.attn(x, attention_mask=mask)[0]

def build(self, input_shape):
Expand Down
4 changes: 2 additions & 2 deletions requirements-torch-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ tensorflow-text~=2.18

# Torch with cuda support.
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.5.1+cu121
torchvision==0.20.1+cu121
torch==2.6.0+cpu
torchvision==0.21.0+cpu

# Jax cpu-only version.
jax[cpu]
Expand Down

0 comments on commit c1eaf64

Please sign in to comment.