Skip to content

Commit

Permalink
Attribution for custom jvp approach
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Jan 26, 2025
1 parent 8c9ca83 commit dbf189e
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,12 @@ def log1mexp(x: ArrayLike) -> ArrayLike:
)


# custom jvp for log1mexp to handle
# the gradient when x is near 0.
# Custom jvp for log1mexp to handle the gradient when x is near 0.
#
# Inspired by the approach taken here for the function log1mexp(-x):
# https://github.com/google-research/google-research/blob/14e984cdb8630a7e3d210dff8760fc06d490fc4b/diffusion_distillation/diffusion_distillation/utils.py#L364-L370
# That code is (c) 2024 The Google Research Authors and licensed under
# an Apache 2.0 License.
log1mexp.defjvps(lambda t, ans, x: -t / jnp.expm1(-x))


Expand Down

0 comments on commit dbf189e

Please sign in to comment.