diff --git a/.bumpversion.toml b/.bumpversion.toml index 5865569..3b0a8b1 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "v0.5.2" +current_version = "v0.5.3" commit = true commit_args = "--no-verify" tag = true diff --git a/pyproject.toml b/pyproject.toml index 7ef1bcc..4dce3e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "fmmax" -version = "v0.5.2" +version = "v0.5.3" description = "Fourier modal method with Jax" readme = "README.md" requires-python = ">=3.7" diff --git a/src/fmmax/__init__.py b/src/fmmax/__init__.py index 68225e5..c64ea25 100644 --- a/src/fmmax/__init__.py +++ b/src/fmmax/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -__version__ = "v0.5.2" +__version__ = "v0.5.3" from . import ( basis, diff --git a/src/fmmax/vector.py b/src/fmmax/vector.py index 473c27a..504d91a 100644 --- a/src/fmmax/vector.py +++ b/src/fmmax/vector.py @@ -144,6 +144,7 @@ def compute_tangent_field( Returns: The normal field, `(tx, ty)`. """ + arr = jax.lax.stop_gradient(arr) batch_shape = arr.shape[:-2] arr = utils.atleast_nd(arr, n=3) arr = arr.reshape((-1,) + arr.shape[-2:])