
Variably sized observations #1324

Open
manuelgloeckler opened this issue Dec 10, 2024 · 1 comment
Labels: enhancement, hackathon

Comments

@manuelgloeckler (Contributor)

Is your feature request related to a problem? Please describe.

Handling variable-size observations, such as those consumed by permutation-invariant embedding networks, RNNs, or Transformers, currently requires padding inputs (e.g., with NaNs) to a fixed size. Padding is convenient for batching during training, but at test time it would be preferable to also support tensors of varying lengths directly.
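For context, a minimal sketch of the padding step described above, assuming 1-D observations of varying lengths; pad_sequence and its padding_value argument are standard PyTorch, while the variable names are illustrative:

import torch
from torch.nn.utils.rnn import pad_sequence

# Three hypothetical observations of different lengths.
observations = [torch.randn(5), torch.randn(8), torch.randn(3)]

# Pad with NaNs to a fixed (batch, max_len) shape for training.
batch = pad_sequence(observations, batch_first=True, padding_value=float("nan"))
print(batch.shape)  # torch.Size([3, 8])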

Unfortunately, the current input_shape checks prevent this, even when the underlying methods could handle variable-length inputs without issue.

As a workaround, it's necessary to manually override the inferred shapes to bypass these checks:

import numpy as np
import torch

# Bypass the static shape checks by overwriting the cached shapes.
x_o = torch.tensor(np.array(x_o))
posterior._x_shape = (1, x_o.shape[0], x_o.shape[1])
posterior.posterior_estimator._condition_shape = x_o.shape
posterior.sample((n_samples,), x=x_o, show_progress_bars=False)

Describe the solution you'd like

Shape checks should only be enforced where a static shape is truly necessary. Specifically:

  • If an embedding network is used, shape checks should apply to the output of the embedding network (see the sketch after this list).
  • Eliminate redundant shape checks so that variable-length inputs are not unnecessarily constrained.
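A rough sketch of the first point, assuming a permutation-invariant embedding; this is purely illustrative and not the existing sbi API. The static check runs on the embedding output, whose shape is fixed, while the raw observation may have any length:

import torch
import torch.nn as nn

class MeanEmbedding(nn.Module):
    """Permutation-invariant embedding: average over the variable-length axis."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.net = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, set_size, in_features) with arbitrary set_size.
        return self.net(x).mean(dim=1)  # -> (batch, out_features)

def check_embedded_shape(embedding, x, expected):
    # Hypothetical helper: validate the shape *after* embedding,
    # so the raw input may have any length.
    emb = embedding(x)
    if emb.shape[1:] != torch.Size(expected):
        raise ValueError(f"expected embedded shape {expected}, got {tuple(emb.shape[1:])}")
    return emb

embed = MeanEmbedding(in_features=2, out_features=16)
for set_size in (3, 7, 50):  # all pass, despite different input lengths
    check_embedded_shape(embed, torch.randn(1, set_size, 2), expected=(16,))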
manuelgloeckler added the enhancement and hackathon labels on Dec 10, 2024
@janfb (Contributor) commented Dec 13, 2024

Related to #218.
