[nnx] improve custom_vjp ergonomics #4489

Open
wants to merge 1 commit into
base: main


cgarciae
Collaborator

@cgarciae cgarciae commented Jan 19, 2025

What does this PR do?

Updates the signature of the bwd function to

(*inputs, residual, output_gradient) -> inputs_tangent

Where each element in *inputs is either:

  • The exact input value if it was declared as non-differentiable in nondiff_argnums.
  • A State object representing the gradient of state updates if the input is a graph node.
  • None for all (possibly nested) JAX Arrays.

Previously, output_gradient contained an (input_updates_gradient, true_output_gradient) tuple, and *inputs were only present if marked in nondiff_argnums. This change makes output_gradient compatible with JAX, at the cost of making *inputs incompatible with JAX.
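For comparison, here is a minimal sketch of the plain jax.custom_vjp convention that output_gradient is being aligned with. It uses only JAX, not nnx.custom_vjp, and the names scaled_sin, scaled_sin_fwd and scaled_sin_bwd are made up for illustration. In plain JAX, bwd only receives the inputs listed in nondiff_argnums, followed by the residual and the output gradient; the signature in this PR differs by passing every input.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Plain jax.custom_vjp (not nnx.custom_vjp); `scale` is a non-differentiable input.
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def scaled_sin(scale, x):
    return scale * jnp.sin(x)


def scaled_sin_fwd(scale, x):
    # fwd takes the same arguments as the primal and returns (output, residual).
    return scaled_sin(scale, x), jnp.cos(x)


def scaled_sin_bwd(scale, residual, output_gradient):
    # Non-differentiable inputs come first, then the residual, then the output gradient.
    cos_x = residual
    # Return one cotangent per differentiable input, as a tuple.
    return (output_gradient * scale * cos_x,)


scaled_sin.defvjp(scaled_sin_fwd, scaled_sin_bwd)

# d/dx [2 * sin(x)] at x = 0 is 2 * cos(0).
grad_at_zero = jax.grad(scaled_sin, argnums=1)(2.0, 0.0)
```

Here output_gradient is a single value matching the function output, which is the shape of argument this PR restores in place of the earlier two-element tuple.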

NOTE: consider following the Linen strategy of using vjp to compute the gradient wrt the desired params instead of always providing it. We would need to add nnx.vjp for this.
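As a rough illustration of the vjp-based approach mentioned in the note, the sketch below uses plain jax.vjp, since nnx.vjp does not exist yet according to the note. The linear function and the params dictionary are hypothetical stand-ins, not code from this PR. The gradient is computed only with respect to the arguments passed as primals; everything else is closed over.

```python
import jax
import jax.numpy as jnp


def linear(params, x):
    # Hypothetical stand-in for a module's forward pass.
    return jnp.dot(x, params["w"]) + params["b"]


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
x = jnp.array([3.0, 4.0])

# Differentiate only with respect to `params` by closing over `x`.
output, vjp_fn = jax.vjp(lambda p: linear(p, x), params)

# vjp_fn maps an output gradient to a tuple with one cotangent per primal input.
(params_grad,) = vjp_fn(jnp.ones_like(output))
```

With this pattern the caller chooses which values to differentiate at the point where the vjp is taken, instead of always receiving a gradient for every input.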


@cgarciae cgarciae force-pushed the nnx-optimize-fingerprint branch 5 times, most recently from 8c13c94 to dd5755a Compare January 25, 2025 02:27
@cgarciae cgarciae force-pushed the nnx-optimize-fingerprint branch 6 times, most recently from 7aea7d9 to 48c59d3 Compare February 4, 2025 18:09
@cgarciae cgarciae force-pushed the nnx-optimize-fingerprint branch 2 times, most recently from ed4c6a0 to 19f78a6 Compare February 13, 2025 07:18
@cgarciae cgarciae force-pushed the nnx-refactor-custom-vjp branch from 1b7bab3 to 8510c16 Compare March 5, 2025 17:11
@cgarciae cgarciae changed the base branch from nnx-optimize-fingerprint to main March 5, 2025 17:12
1 participant