Avoid passing `deterministic` to complicated nested modules by using the variable dict
#2928
I have recently been working with modules that have complicated nested structures, and I find it tedious to have to pass parameters like `deterministic` down through every level, as in this WMT example. I think we can do something similar here in Flax by making use of the variable dict and passing global default values for these parameters:

```python
class Dropout(Module):
  rate: float
  broadcast_dims: Sequence[int] = ()
  deterministic: Optional[bool] = None

  @compact
  def __call__(self, inputs, deterministic: Optional[bool] = None):
    # use the variable `training` to indicate deterministic
    eval_mode = (not self.variable(col='properties', name='training').value
                 if self.has_variable(col='properties', name='training')
                 else None)
    deterministic = merge_param('deterministic',
                                self.deterministic,
                                deterministic,
                                eval_mode)
    ...
    # apply dropout if not deterministic
```
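For clarity, here is a minimal self-contained sketch of the precedence logic the proposal relies on. This is a hypothetical generalisation for illustration only, not Flax's actual `merge_param` (which takes exactly two candidate values): earlier candidates (the constructor attribute, then the call argument) take precedence over later ones (the variable-dict default).

```python
from typing import Optional


def merge_param(name: str, *candidates: Optional[bool]) -> bool:
    """Return the first non-None candidate, in priority order.

    Hypothetical helper: earlier candidates (constructor attribute,
    call argument) override later ones (the variable-dict default).
    """
    for value in candidates:
        if value is not None:
            return value
    raise ValueError(f"Parameter {name!r} must be set somewhere.")


# constructor attribute and call argument unset: the variable-dict
# default (eval_mode) decides
assert merge_param('deterministic', None, None, False) is False
# an explicit call argument overrides the variable-dict default
assert merge_param('deterministic', None, True, False) is True
```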
```python
class StochasticDenseBlock(Module):

  @compact
  def __call__(self, inputs):
    # no need to pass deterministic to Dropout
    x = Dense(5)(inputs)
    x = Dropout(.5)(x)
    return relu(x)


class Model(Module):
  out_layer: Module

  @compact
  def __call__(self, inputs):
    # no need to pass deterministic to out_layer or any StochasticDenseBlock
    x = StochasticDenseBlock()(inputs)
    x = StochasticDenseBlock()(x)
    x = StochasticDenseBlock()(x)
    x = self.out_layer(x)
    return x
```
```python
mdl = Model(StochasticDenseBlock())
rng = random.PRNGKey(0)
x = np.ones(10)
params = mdl.init(
    rngs=dict(params=rng),
    inputs=x,
    # pass global default values for variables absent from the variable dict;
    # this does not need to mirror the nested module structure
    default={'properties': {'training': False}},
)['params']


def mdl_apply(inputs, training, rng=None):
  return mdl.apply(dict(params=params),
                   inputs=inputs,
                   rngs=dict(dropout=rng),
                   default={'properties': {'training': training}})


mdl_apply_train = jit(partial(mdl_apply, training=True))
mdl_apply_eval = jit(partial(mdl_apply, training=False))
y1 = mdl_apply_train(x, rng=rng)
y2 = mdl_apply_eval(x)
```

Following this style, we only need to handle `deterministic` once, inside `Dropout` itself. This change maintains backward compatibility with code that passes `deterministic` explicitly. I have not taken a deep look at the Flax internals, but we can probably add a `default` argument to `init` and `apply` to support this. A similar question has been asked here: #1561
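For contrast, the explicit threading that this proposal removes can be sketched in plain Python. These are toy classes, not Flax modules, but they show how a flag like `deterministic` must currently be accepted and forwarded at every level of the module tree:

```python
import random as pyrandom


class Dropout:
    def __init__(self, rate):
        self.rate = rate

    def __call__(self, xs, deterministic):
        if deterministic:
            return xs
        keep = 1.0 - self.rate
        return [0.0 if pyrandom.random() < self.rate else x / keep
                for x in xs]


class DenseBlock:
    def __call__(self, xs, deterministic):      # flag accepted here...
        return Dropout(0.5)(xs, deterministic)  # ...only to forward it


class Model:
    def __call__(self, xs, deterministic):      # ...and threaded here too
        for _ in range(3):
            xs = DenseBlock()(xs, deterministic)
        return xs


# in eval mode the inputs pass through unchanged
assert Model()([1.0] * 10, deterministic=True) == [1.0] * 10
```

Every intermediate class takes a `deterministic` argument only to pass it along; with the variable-dict default, only the leaf `Dropout` and the top-level `init`/`apply` calls would mention it.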
Reply:

First thing to note is that Flax is very explicit about everything; it doesn't try to do anything for you, to give you maximum control. That said, I share the sentiment, and in the past created the Scope Flags FLIP (#2131) to try to minimize passing down these parameters; take a look at some of the comments there.

The current situation is that flags are indeed implemented (`Module.scope.flags` exists), but currently we only use them to power the `Module.is_initializing` method; we don't expose them, and our layers don't use them (apart from their use of `is_initializing`). In theory we could have something like this:

```python
y = module.apply({'params': params}, x, flags={'deterministic': False})
```

I'll try to bring up this idea again with the team, pointing to this use case, to see if there is renewed interest.
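To make the flags idea concrete, here is a toy sketch of how a flag set at the root scope could become visible in arbitrarily nested child scopes. This is purely illustrative pseudocode of the propagation idea, with invented names; it does not reflect Flax's actual `Scope` internals:

```python
class Scope:
    """Toy scope: flags set at the root are visible in every child scope."""

    def __init__(self, flags=None, parent=None):
        self._flags = flags or {}
        self._parent = parent

    def child(self):
        # child scopes carry no flags of their own; lookups walk upward
        return Scope(parent=self)

    def flag(self, name, default=None):
        if name in self._flags:
            return self._flags[name]
        if self._parent is not None:
            return self._parent.flag(name, default)
        return default


root = Scope(flags={'deterministic': False})
inner = root.child().child()  # arbitrarily deep nesting
assert inner.flag('deterministic') is False
assert inner.flag('unset', default=True) is True
```

A leaf layer could then consult `self.scope.flag('deterministic')` as one more candidate in its parameter-merging logic, without any intermediate module mentioning the flag.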