Skip to content

Commit

Permalink
Ensure inputs to EmbedModule are integers
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Feb 6, 2025
1 parent bced2ef commit 4e14d44
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion neurobayes/flax_nets/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ class EmbedModule(nn.Module):

@nn.compact
def __call__(self, x):
x = x.astype(jnp.int32)
return nn.Embed(
num_embeddings=self.num_embeddings,
features=self.features,
name=self.layer_name
)(x)


class LayerNormModule(nn.Module):
layer_name: str = 'layernorm'
Expand Down

0 comments on commit 4e14d44

Please sign in to comment.