```python
class RestrictedTransformForConditional(torch_tf.Transform):
    """
    Class to restrict the transform to fewer dimensions for conditional sampling.
    The resulting transform transforms only the free dimensions of the conditional.
    Notably, the `log_abs_det` is computed given all dimensions. However, the
    `log_abs_det` stemming from the fixed dimensions is a constant and drops out during
    MCMC.

    All methods work in a similar way: `full_theta` will first have all entries of the
    `condition` and then override the entries that should be sampled with `theta`. In
    case `theta` is a batch of `theta` (e.g. multi-chain MCMC), we have to repeat
    `theta_condition` to match the batch size.

    This is needed for the MCMC initialization functions when conditioning and when
    transforming the samples back into the original theta space after sampling.
    """

    def __init__(
        self,
        transform: torch_tf.Transform,
        condition: Tensor,
        dims_to_sample: List[int],
    ) -> None:
        ...  # method body not included in this excerpt
```
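The mechanism described in the docstring can be sketched as follows. This is a minimal illustration, not the sbi implementation: the helper names `fill_full_theta`, `restricted_forward`, and `restricted_log_abs_det` are hypothetical, and the sketch assumes `condition` holds one value per dimension, `theta` has shape `(batch, len(dims_to_sample))`, and the transform is elementwise. The usage at the end shows why the fixed dimensions only shift `log_abs_det` by a constant.

```python
from typing import List

import torch
from torch import Tensor
from torch.distributions import transforms as torch_tf


def fill_full_theta(theta: Tensor, condition: Tensor, dims_to_sample: List[int]) -> Tensor:
    """Hypothetical helper: copy `condition`, then overwrite the free dims with `theta`."""
    theta = torch.atleast_2d(theta)  # a single sample becomes a batch of one
    # Repeat the condition to match the batch size (e.g. one row per MCMC chain).
    full_theta = condition.reshape(1, -1).repeat(theta.shape[0], 1)
    # Override only the entries that are being sampled.
    full_theta[:, dims_to_sample] = theta
    return full_theta


def restricted_forward(transform, theta, condition, dims_to_sample) -> Tensor:
    """Apply the full transform, then keep only the free dimensions."""
    full_theta = fill_full_theta(theta, condition, dims_to_sample)
    return transform(full_theta)[:, dims_to_sample]


def restricted_log_abs_det(transform, theta, condition, dims_to_sample) -> Tensor:
    """log|det J| computed over all dimensions, as the docstring describes."""
    full_theta = fill_full_theta(theta, condition, dims_to_sample)
    # Assumes an elementwise transform, which returns one term per dimension.
    return transform.log_abs_det_jacobian(full_theta, transform(full_theta)).sum(-1)


condition = torch.tensor([0.5, -1.0, 2.0])
theta = torch.tensor([[0.1, 0.3], [0.2, -0.4]])  # two chains; dims 0 and 2 are free
tf = torch_tf.ExpTransform()

y = restricted_forward(tf, theta, condition, [0, 2])  # exp(theta), shape (2, 2)
full_lad = restricted_log_abs_det(tf, theta, condition, [0, 2])
# For ExpTransform, log|det J| is the sum of the inputs, so the fixed dim 1
# contributes condition[1] = -1.0 to every chain: a constant offset.
offset = full_lad - theta.sum(-1)
```

Because the offset is identical for every value of `theta`, it cancels in MCMC acceptance ratios, which is why computing `log_abs_det` over all dimensions is harmless.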
`RestrictedTransformForConditional` is currently typed as a `torch.Transform` while also holding a transform as a member variable (sbi/sbi/utils/conditional_density_utils.py, lines 386 to 409 in bae6994).

The conditioning with `theta` makes the interface incompatible with `torch.Transform`. For example, the standard `inv` takes no arguments (in `torch.distributions` it is a property that returns the inverse transform), whereas our `inv(theta)` is called with an argument. This example could be fixed by renaming our method to `restricted_inv(theta)`.

The problem is caused by the fact that `RestrictedTransformForConditional` is a torch `Transform` but also takes a torch `Transform` as an argument. This can be refactored, but I also think it is not a priority for the release milestone, as `conditional_potential` is not used by anything at the moment.
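One possible refactoring is composition: wrap the transform instead of subclassing `torch.Transform`, and expose the conditional operations under distinct names such as `restricted_inv(theta)`, so they no longer shadow torch's `inv`. The sketch below is hypothetical (the class name `ConditionalTransformWrapper` and its methods are not sbi API). It assumes the wrapped transform's inverse expects the fixed dimensions in transformed space, so `restricted_inv` fills them with the transformed condition.

```python
from typing import List

import torch
from torch import Tensor
from torch.distributions import transforms as torch_tf


class ConditionalTransformWrapper:
    """Hypothetical refactor: holds a torch Transform instead of being one."""

    def __init__(self, transform: torch_tf.Transform, condition: Tensor, dims_to_sample: List[int]) -> None:
        self.transform = transform
        self.condition = condition.reshape(1, -1)
        self.dims_to_sample = dims_to_sample

    def _embed(self, theta: Tensor, fixed: Tensor) -> Tensor:
        # Fill every dimension from `fixed`, then override the free ones with `theta`.
        theta = torch.atleast_2d(theta)
        full = fixed.repeat(theta.shape[0], 1)
        full[:, self.dims_to_sample] = theta
        return full

    def restricted_forward(self, theta: Tensor) -> Tensor:
        # The fixed dimensions live in the original (untransformed) space.
        full = self._embed(theta, self.condition)
        return self.transform(full)[:, self.dims_to_sample]

    def restricted_inv(self, theta: Tensor) -> Tensor:
        # Here the fixed dimensions must be given in the transformed space.
        full = self._embed(theta, self.transform(self.condition))
        # torch's own `inv` is a property returning the inverse Transform; calling
        # that object applies it, so the wrapped transform keeps its standard interface.
        return self.transform.inv(full)[:, self.dims_to_sample]


tf = torch_tf.ComposeTransform([torch_tf.AffineTransform(loc=1.0, scale=2.0), torch_tf.ExpTransform()])
wrapper = ConditionalTransformWrapper(tf, torch.tensor([0.5, -1.0, 2.0]), dims_to_sample=[0, 2])
theta = torch.tensor([[0.1, 0.3], [0.2, -0.4]])
roundtrip = wrapper.restricted_inv(wrapper.restricted_forward(theta))  # recovers theta
```

Since the wrapper is not a `torch.Transform`, no caller can mistake `restricted_inv(theta)` for torch's argument-free `inv`; the trade-off is that it cannot be passed directly to APIs that expect a `Transform`.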