From 4e14d4499a71c3e3a64dd8eeb1a840cc4b6a7543 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Wed, 5 Feb 2025 20:01:44 -0800 Subject: [PATCH] Ensure inputs to EmbedModule are integers --- neurobayes/flax_nets/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neurobayes/flax_nets/transformer.py b/neurobayes/flax_nets/transformer.py index 10ee2f3..838b395 100644 --- a/neurobayes/flax_nets/transformer.py +++ b/neurobayes/flax_nets/transformer.py @@ -10,12 +10,12 @@ class EmbedModule(nn.Module): @nn.compact def __call__(self, x): + x = x.astype(jnp.int32) return nn.Embed( num_embeddings=self.num_embeddings, features=self.features, name=self.layer_name )(x) - class LayerNormModule(nn.Module): layer_name: str = 'layernorm'