jax.numpy reductions: avoid upcast of f16 when dtype is specified by user #26403

Merged · 1 commit · Feb 12, 2025
9 changes: 5 additions & 4 deletions jax/_src/numpy/reductions.py
@@ -231,7 +231,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array:
return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
bool_op=lax.bitwise_or, upcast_f16_for_computation=(dtype is None),
Collaborator review comment: Nit: you can drop parens around dtype is None here and elsewhere.

axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.psum,
promote_integers=promote_integers)
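
For reference, a minimal sketch of the user-visible effect of this change, mirroring the tests added in tests/lax_numpy_reducers_test.py below:

from functools import partial
import jax
import jax.numpy as jnp

x = jnp.ones(10, dtype=jnp.float16)

# Default behaviour: the float16 input is upcast to float32 for the
# accumulation and the result is cast back to float16.
print(jax.make_jaxpr(jnp.sum)(x))
# -> convert_element_type, reduce_sum, convert_element_type

# With an explicit dtype, the reduction now runs in float16 directly.
print(jax.make_jaxpr(partial(jnp.sum, dtype=jnp.float16))(x))
# -> reduce_sum only
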
@@ -319,7 +319,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array:
return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric,
bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
bool_op=lax.bitwise_and, upcast_f16_for_computation=(dtype is None),
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, promote_integers=promote_integers)

@@ -865,9 +865,10 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
[6. ]], dtype=float32)
"""
return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
where=where)
where=where, upcast_f16_for_computation=(dtype is None))

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'upcast_f16_for_computation'),
inline=True)
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, *,
upcast_f16_for_computation: bool = True,
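
A note on the _mean change: upcast_f16_for_computation selects between two differently structured computations at trace time, which is why it is added to static_argnames. A minimal sketch of that pattern (the reduce_f16 helper below is illustrative, not part of the PR):

from functools import partial
import jax
import jax.numpy as jnp

# A boolean that switches between two different traced computations must be
# a static (compile-time) argument; otherwise jit cannot branch on it with a
# Python-level `if`.
@partial(jax.jit, static_argnames=('upcast',))
def reduce_f16(x, *, upcast: bool = True):
    if upcast:
        # Accumulate in float32, then cast back to the input dtype.
        return jnp.sum(x.astype(jnp.float32)).astype(x.dtype)
    # Accumulate directly in the input dtype.
    return jnp.sum(x, dtype=x.dtype)

x = jnp.ones(8, dtype=jnp.float16)
print(reduce_f16(x, upcast=True))   # float16 result, float32 accumulation
print(reduce_f16(x, upcast=False))  # float16 end to end
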
52 changes: 51 additions & 1 deletion tests/lax_numpy_reducers_test.py
@@ -231,7 +231,12 @@ def np_fun(x):
np.uint32: 3e-7, np.float32: 1e-3, np.complex64: 1e-3,
np.float64: 1e-5, np.complex128: 1e-5}
tol = jtu.tolerance(dtype, tol_spec)
tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
if out_dtype in [np.float16, dtypes.bfloat16]:
# For 16-bit out_dtype, NumPy will accumulate in float32, while JAX
# accumulates in 16-bit, so we need a larger tolerance.
tol = 1e-1
else:
tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=jnp.bfloat16 not in (dtype, out_dtype),
tol=tol)
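
To illustrate why the looser tolerance is needed: a long float16 sum accumulated in float16 loses precision once the running total is large relative to the addends, whereas a float32 accumulator does not. A standalone NumPy sketch (not part of the test suite):

import numpy as np

x = np.ones(3000, dtype=np.float16)

# Accumulating in float32 recovers the exact sum.
print(x.astype(np.float32).sum())    # 3000.0

# A strict float16 running sum stalls at 2048, because 2048 + 1 rounds
# back to 2048 at float16 precision (the spacing there is 2).
acc = np.float16(0)
for v in x:
    acc = np.float16(acc + v)
print(acc)                            # 2048.0
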
@@ -930,5 +935,50 @@ def np_op(x, axis=None, dtype=None, include_initial=False):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
op=['sum', 'prod'],
dtype=['float16', 'bfloat16'],
)
def testReducerF16Casts(self, op, dtype):
rng = jtu.rand_default(self.rng())
x = jnp.asarray(rng((10,), dtype))

func = getattr(jnp, op)
reduce_p = getattr(jax.lax, f"reduce_{op}_p")
conv_elem_p = jax.lax.convert_element_type_p

# Without dtype specified, the reduction is sandwiched between two casts.
jaxpr1 = jax.make_jaxpr(func)(x)
self.assertEqual(
[eqn.primitive for eqn in jaxpr1.eqns],
[conv_elem_p, reduce_p, conv_elem_p])

# With dtype specified, the reduction happens without a cast.
jaxpr2 = jax.make_jaxpr(partial(func, dtype=dtype))(x)
self.assertEqual([eqn.primitive for eqn in jaxpr2.eqns], [reduce_p])

@jtu.sample_product(
dtype=['float16', 'bfloat16'],
)
def testMeanF16Casts(self, dtype):
rng = jtu.rand_default(self.rng())
x = jnp.asarray(rng((10,), dtype))

reduce_sum_p = jax.lax.reduce_sum_p
div_p = jax.lax.div_p
conv_elem_p = jax.lax.convert_element_type_p

# Without dtype specified, the reduction is sandwiched between two casts.
jaxpr1 = jax.make_jaxpr(jnp.mean)(x)
self.assertEqual(
[eqn.primitive for eqn in jaxpr1.eqns],
[conv_elem_p, reduce_sum_p, div_p, conv_elem_p])

# With dtype specified, the reduction happens without a cast.
jaxpr2 = jax.make_jaxpr(partial(jnp.mean, dtype=dtype))(x)
self.assertEqual(
[eqn.primitive for eqn in jaxpr2.eqns],
[reduce_sum_p, div_p])

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())