Skip to content

Commit

Permalink
Fix the dtype selection
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Feb 4, 2025
1 parent 8024315 commit b5c7ac5
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions neurobayes/flax_nets/deterministic_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def __init__(self,
input_shape = (input_shape,) if isinstance(input_shape, int) else input_shape
self.model = architecture

is_transformer = any(base.__name__.lower().find('transformer') >= 0
for base in architecture.__mro__)
is_transformer = 'transformer' in self.model.__class__.__name__.lower()
input_dtype = jnp.int32 if is_transformer else jnp.float32

if loss not in ['homoskedastic', 'heteroskedastic', 'classification']:
Expand Down

0 comments on commit b5c7ac5

Please sign in to comment.