Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary device_put #1952

Merged
merged 1 commit into from
Jan 17, 2025
Merged

Remove unnecessary device_put #1952

merged 1 commit into from
Jan 17, 2025

Conversation

fehiepsi
Copy link
Member

Fixes #1949

All tests passed after the change.

@martinjankowiak
Copy link
Collaborator

do we need to test on cuda to be sure?

@fehiepsi
Copy link
Member Author

I think originally, we used device_put to move scalars to devices, to avoid recompiling issues. But it seems unnecessary for recent jax releases (if compiling tests pass).

@fehiepsi fehiepsi merged commit 4704656 into pyro-ppl:master Jan 17, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Potential random key reuse
2 participants