chore(deps): update dependency jaxlib to v0.5.0 #590
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR contains the following updates:
0.4.23
->0.5.0
Release Notes
jax-ml/jax (jaxlib)
v0.5.0
As of this release, JAX now uses
effort-based versioning.
Since this release makes a breaking change to PRNG key semantics that
may require users to update their code, we are bumping the "meso" version of JAX
to signify this.
Breaking changes
Enable
jax_threefry_partitionable
by default (seethe update note).
This release drops support for Mac x86 wheels. Mac ARM of course remains
supported. For a recent discussihttps://github.com/jax-ml/jax/discussions/22936iscussions/22936.
Two key factors motivated this decision:
would prefer to ship no release than a broken release.
developers at this point. So it is difficult for us to fix this kind of
problem even if we wanted to.
We are open to readding support for Mac x86 if the community is willing
to help support that platform: in particular, we would need the JAX test
suite to pass cleanly on Mac x86 before we could ship releases again.
Changes:
supported version until June 2025.
supported version until June 2025.
jax.numpy.einsum
now defaults tooptimize='auto'
rather thanoptimize='optimal'
. This avoids exponentially-scaling trace-time inthe case of many arguments ({jax-issue}
#25214
).jax.numpy.linalg.solve
no longer supports batched 1D argumentson the right hand side. To recover the previous behavior in these cases,
use
solve(a, b[..., None]).squeeze(-1)
.New Features
jax.numpy.fft.fftn
, {func}jax.numpy.fft.rfftn
,{func}
jax.numpy.fft.ifftn
, and {func}jax.numpy.fft.irfftn
now supporttransforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}
#25606
for more details.{func}
jax.ffi.register_ffi_type_id
function..as_text()
method now supports thedebug_info
optionto include debugging information, e.g., source location, in the output.
Deprecations
jax.interpreters.xla
,abstractify
andpytype_aval_mappings
are now deprecated, having been replaced by symbols of the same name
in {mod}
jax.core
.jax.scipy.special.lpmn
and {func}jax.scipy.special.lpmn_values
are deprecated, following their deprecation in SciPy v1.15.0. There are
no plans to replace these deprecated functions with new APIs.
jax.extend.ffi
submodule was moved to {mod}jax.ffi
, and theprevious import path is deprecated.
Deletions
jax_enable_memories
flag has been deleted and the behavior of that flagis on by default.
jax.lib.xla_client
, the previously-deprecatedDevice
andXlaRuntimeError
symbols have been removed; instead usejax.Device
and
jax.errors.JaxRuntimeError
respectively.jax.experimental.array_api
module has been removed after beingdeprecated in JAX v0.4.32. Since that release, {mod}
jax.numpy
supportsthe array API directly.
v0.4.38
Changes:
jax.tree.flatten_with_path
andjax.tree.map_with_path
are addedas shortcuts of the corresponding
tree_util
functions.Deprecations
jax.core
namespace have been deprecated.Most were no-ops, were little-used, or can be replaced by APIs of the same
name in {mod}
jax.extend.core
; see the documentation for {mod}jax.extend
for information on the compatibility guarantees of these semi-public extensions.
jax.core
:check_eqn
,check_type
,check_valid_jaxtype
, andnon_negative_dim
.jax.lib.xla_bridge
:xla_client
anddefault_backend
.jax.lib.xla_client
:_xla
andbfloat16
.jax.numpy
:round_
.New Features
jax.export.export
can be used for device-polymorphic export withshardings constructed with {func}
jax.sharding.AbstractMesh
.See the jax.export documentation.
jax.lax.split
. This is a primitive version of{func}
jax.numpy.split
, added because it yields a more compacttranspose during automatic differentiation.
v0.4.36
Breaking Changes
This release lands "stackless", an internal change to JAX's tracing
machinery. We made trace dispatch purely a function of context rather than a
function of both context and data. This let us delete a lot of machinery for
managing data-dependent tracing: levels, sublevels,
post_process_call
,new_base_main
,custom_bind
, and so on. The change should only affectusers that use JAX internals.
If you do use JAX internals then you may need to
update your code (see
jax-ml/jax@c36e1f7
for clues about how to do this). There might also be version skew
issues with JAX libraries that do this. If you find this change breaks your
non-JAX-internals-using code then try the
config.jax_data_dependent_tracing_fallback
flag as a workaround, and ifyou need help updating your code then please file a bug.
{func}
jax.experimental.jax2tf.convert
withnative_serialization=False
or with
enable_xla=False
have been deprecated since July 2024, withJAX version 0.4.31. Now we removed support for these use cases.
jax2tf
with native serialization will still be supported.
In
jax.interpreters.xla
, thexb
,xc
, andxe
symbols have been removedafter being deprecated in JAX v0.4.31. Instead use
xb = jax.lib.xla_bridge
,xc = jax.lib.xla_client
, andxe = jax.lib.xla_extension
.The deprecated module
jax.experimental.export
has been removed. It was replacedby {mod}
jax.export
in JAX v0.4.30. See the migration guidefor information on migrating to the new API.
The
initial
argument to {func}jax.nn.softmax
and {func}jax.nn.log_softmax
has been removed, after being deprecated in v0.4.27.
Calling
np.asarray
on typed PRNG keys (i.e. keys produced by :func:jax.random.key
)now raises an error. Previously, this returned a scalar object array.
The following deprecated methods and functions in {mod}
jax.export
havebeen removed:
jax.export.DisabledSafetyCheck.shape_assertions
: it had no effectalready.
jax.export.Exported.lowering_platforms
: useplatforms
.jax.export.Exported.mlir_module_serialization_version
:use
calling_convention_version
.jax.export.Exported.uses_shape_polymorphism
:use
uses_global_constants
.lowering_platforms
kwarg for {func}jax.export.export
: useplatforms
instead.The kwargs
symbolic_scope
andsymbolic_constraints
from{func}
jax.export.symbolic_args_specs
have been removed. They weredeprecated in June 2024. Use
scope
andconstraints
instead.Hashing of tracers, which has been deprecated since version 0.4.30, now
results in a
TypeError
.Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
replaces previous build.py usage. Run
python build/build.py --help
formore details. Brief overview of the new subcommand options:
build
: Builds JAX wheel packages. For e.g.,python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
requirements_update
: Updates requirements_lock.txt files.{func}
jax.scipy.linalg.toeplitz
now does implicit batching on multi-dimensionalinputs. To recover the previous behavior, you can call {func}
jax.numpy.ravel
on the function inputs.
{func}
jax.scipy.special.gamma
and {func}jax.scipy.special.gammasgn
nowreturn NaN for negative integer inputs, to match the behavior of Scihttps://github.com/scipy/scipy/pull/21827scipy/pull/21827.
jax.clear_backends
was removed after being deprecated in v0.4.26.We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
call that we guarantee export stability. This is because this custom call
relies on Triton IR, which is not guaranteed to be stable. If you need
to export code that uses this custom call, you can use the
disabled_checks
parameter. See more details in the documentation.
New Features
jax.jit
got a newcompiler_options: dict[str, Any]
argument, forpassing compilation options to XLA. For the moment it's undocumented and
may be in flux.
jax.tree_util.register_dataclass
now allows metadata fields to bedeclared inline via {func}
dataclasses.field
. See the function documentationfor examples.
jax.numpy.put_along_axis
.jax.lax.linalg.eig
and the relatedjax.numpy
functions({func}
jax.numpy.linalg.eig
and {func}jax.numpy.linalg.eigvals
) are nowsupported on GPU. See {jax-issue}
#24663
for more details.jax_exec_time_optimization_effort
andjax_memory_fitting_effort
, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.Bug fixes
result in an indexing overflow for batch sizes close to int32 max. See
{jax-issue}
#24843
for more details.Deprecations
jax.lib.xla_extension.ArrayImpl
andjax.lib.xla_client.ArrayImpl
are deprecated;use
jax.Array
instead.jax.lib.xla_extension.XlaRuntimeError
is deprecated; usejax.errors.JaxRuntimeError
instead.
v0.4.35
Breaking Changes
jax.numpy.isscalar
now returns True for any array-like object withzero dimensions. Previously it only returned True for zero-dimensional
array-like objects with a weak dtype.
jax.experimental.host_callback
has been deprecated since March 2024, withJAX version 0.4.26. Now we removed it.
See {jax-issue}
#20385
for a discussion of alternatives.Changes:
jax.lax.FftType
was introduced as a public name for the enum of FFToperations. The semi-public API
jax.lib.xla_client.FftType
has beendeprecated.
libtpu
package rather thanlibtpu-nightly
. For the next few releases JAX will pin an empty version oflibtpu-nightly
as well aslibtpu
to ease the transition; that dependencywill be removed in Q1 2025.
Deprecations:
jax.lib.xla_client.PaddingType
has been deprecated.No JAX APIs consume this type, so there is no replacement.
jax.pure_callback
and{func}
jax.extend.ffi.ffi_call
undervmap
has been deprecated and so hasthe
vectorized
parameter to those functions. Thevmap_method
parametershould be used instead for better defined behavior. See the discussion in
{jax-issue}
#23881
for more details.jax.lib.xla_client.register_custom_call_target
hasbeen deprecated. Use the JAX FFI instead.
jax.lib.xla_client.dtype_to_etype
,jax.lib.xla_client.ops
,jax.lib.xla_client.shape_from_pyval
,jax.lib.xla_client.PrimitiveType
,jax.lib.xla_client.Shape
,jax.lib.xla_client.XlaBuilder
, andjax.lib.xla_client.XlaComputation
have been deprecated. Use StableHLOinstead.
v0.4.34
New Functionality
supported.
jax.errors.JaxRuntimeError
has been added as a public alias for theformerly private
XlaRuntimeError
type.Breaking changes
jax_pmap_no_rank_reduction
flag is set toTrue
by default.instead).
jax_array.addressable_data(0)) now has a leading (1, ...). Update code
that directly accesses shards accordingly. The rank of the per-shard-shape
now matches that of the global shape which is the same behavior as jit.
This avoids costly reshapes when passing results from pmap into jit.
jax.experimental.host_callback
has been deprecated since March 2024, withJAX version 0.4.26. Now we set the default value of the
--jax_host_callback_legacy
configuration value toTrue
, which means thatif your code uses
jax.experimental.host_callback
APIs, those API callswill be implemented in terms of the new
jax.experimental.io_callback
API.If this breaks your code, for a very limited time, you can set the
--jax_host_callback_legacy
toTrue
. Soon we will remove thatconfiguration option, so you should instead transition to using the
new JAX callback APIs. See {jax-issue}
#20385
for a discussion.Deprecations
jax.numpy.trim_zeros
, non-arraylike arguments or arraylikearguments with
ndim != 1
are now deprecated, and in the future will resultin an error.
jax.core.pp_*
have been removed, afterbeing deprecated in JAX v0.4.30.
jax.lib.xla_client.Device
is deprecated; usejax.Device
instead.jax.lib.xla_client.XlaRuntimeError
has been deprecated. Usejax.errors.JaxRuntimeError
instead.Deletion:
jax.xla_computation
is deleted. It's been 3 months since it's deprecationin 0.4.30 JAX release.
Please use the AOT APIs to get the same functionality as
jax.xla_computation
.jax.xla_computation(fn)(*args, **kwargs)
can be replaced withjax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
..out_info
property ofjax.stages.Lowered
to get theoutput information (like tree structure, shape and dtype).
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
withjax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
.jax.ShapeDtypeStruct
no longer accepts thenamed_shape
argument.The argument was only used by
xmap
which was removed in 0.4.31.jax.tree.map(f, None, non-None)
, which previously emitted aDeprecationWarning
, now raises an error in a future version of jax.None
is only a tree-prefix of itself. To preserve the current behavior, you can
ask
jax.tree.map
to treatNone
as a leaf value by writing:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
.jax.sharding.XLACompatibleSharding
has been removed. Please usejax.sharding.Sharding
.Bug fixes
jax.numpy.cumsum
would produce incorrect outputsif a non-boolean input was provided and
dtype=bool
was specified.jax.numpy.ldexp
to get correct gradient.v0.4.33
This is a patch release on top of jax 0.4.32, that fixes two bugs found in that
release.
A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of
libtpu
.This release fixes an inaccurate result for F64 tanh on CPU (#23590).
v0.4.32
Compare Source
Note: This release was yanked from PyPi because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.
New Functionality
jax.extend.ffi.ffi_call
and {func}jax.extend.ffi.ffi_lowering
to support the use of the new {ref}
ffi-tutorial
to interface with customC++ and CUDA code from JAX.
Changes
jax_enable_memories
flag is set toTrue
by default.jax.numpy
now supports v2023.12 of the Python Array API Standard.See {ref}
python-array-api
for more information.more cases. Previously non-parallel computations were always dispatched
synchronously. You can recover the old behavior by setting
jax.config.update('jax_cpu_enable_async_dispatch', False)
.jax.process_indices
function to replace thejax.host_ids()
function that was deprecated in JAX v0.2.13.numpy.fabs
,jax.numpy.fabs
has beenmodified to no longer support
complex dtypes
.jax.tree_util.register_dataclass
now checks thatdata_fields
and
meta_fields
includes all dataclass fields withinit=True
and only them, if
nodetype
is a dataclass.jax.numpy
functions now have full {class}~jax.numpy.ufunc
interfaces, including {obj}
~jax.numpy.add
, {obj}~jax.numpy.multiply
,{obj}
~jax.numpy.bitwise_and
, {obj}~jax.numpy.bitwise_or
,{obj}
~jax.numpy.bitwise_xor
, {obj}~jax.numpy.logical_and
,{obj}
~jax.numpy.logical_and
, and {obj}~jax.numpy.logical_and
.In future releases we plan to expand these to other ufuncs.
jax.lax.optimization_barrier
, which allows users to preventcompiler optimizations such as common-subexpression elimination and to
control scheduling.
Breaking changes
jax.extend.mlir.mhlo
) has been removed. Use thestablehlo
dialect instead.Deprecations
jax.numpy.clip
and {func}jax.numpy.hypot
areno longer allowed, after being deprecated since JAX v0.4.27.
jax.lib.xla_bridge.xla_client
: use {mod}jax.lib.xla_client
directly.jax.lib.xla_bridge.get_backend
: use {func}jax.extend.backend.get_backend
.jax.lib.xla_bridge.default_backend
: use {func}jax.extend.backend.default_backend
.jax.experimental.array_api
module is deprecated, and importing it is nolonger required to use the Array API.
jax.numpy
supports the array APIdirectly; see {ref}
python-array-api
for more information.jax.core.check_eqn
,jax.core.check_type
, andjax.core.check_valid_jaxtype
are now deprecated, and will be removed inthe future.
jax.numpy.round_
has been deprecated, following removal of the correspondingAPI in NumPy 2.0. Use {func}
jax.numpy.round
instead.jax.dlpack.from_dlpack
is deprecated.The argument to {func}
jax.dlpack.from_dlpack
should be an array fromanother framework that implements the
__dlpack__
protocol.v0.4.31
Compare Source
Deletion
shard_map
as the replacement.Changes
but we now declare this version constraint formally.
supported version until July 2025.
supported version until December 2024.
supported version until January 2025.
jax.numpy.ceil
, {func}jax.numpy.floor
and {func}jax.numpy.trunc
now return the outputof the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
libdevice.10.bc
is no longer bundled with CUDA wheels. It must beinstalled either as a part of local CUDA installation, or via NVIDIA's CUDA
pip wheels.
jax.experimental.pallas.BlockSpec
now expectsblock_shape
tobe passed before
index_map
. The old argument order is deprecated andwill be removed in a future release.
with TPUs/CPUs. For example,
cuda(id=0)
will now beCudaDevice(id=0)
.device
property andto_device
method to {class}jax.Array
, aspart of JAX's Array API support.
Deprecations
polymorphic shapes. From {mod}
jax.core
: removedcanonicalize_shape
,dimension_as_value
,definitely_equal
, andsymbolic_equal_dim
.Instead, return singleton ir.Values unwrapped. Support for wrapped values
will be removed in a future version of JAX.
jax.experimental.jax2tf.convert
withnative_serialization=False
or
enable_xla=False
is now deprecated and this support will be removed ina future version.
Native serialization has been the default since JAX 0.4.16 (September 2023).
jax.random.shuffle
has been removed;instead use
jax.random.permutation
withindependent=True
.v0.4.30
Compare Source
Changes
bumped to 0.4.0 but this has been rolled back in this release to give users
of both TensorFlow and JAX more time to migrate to a newer TensorFlow
release.
jax.experimental.mesh_utils
can now create an efficient mesh for TPU v5e.plugin switch: there are no longer multiple jaxlib variants. You can install
a CPU-only jax with
pip install jax
, no extras required.to exist in
jax.experimental.export
(which is being deprecated),and will now live in
jax.export
.See the documentation.
Deprecations
jax.core.pp_*
are deprecated, and will be removedin a future release.
TypeError
in a future JAXrelease. This previously was the case, but there was an inadvertent regression in
the last several JAX releases.
jax.experimental.export
is deprecated. Use {mod}jax.export
instead.See the migration guide.
x
andy
,x.astype(y)
will raise a warning. To silence it usex.astype(y.dtype)
.jax.xla_computation
is deprecated and will be removed in a future release.Please use the AOT APIs to get the same functionality as
jax.xla_computation
.jax.xla_computation(fn)(*args, **kwargs)
can be replaced withjax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
..out_info
property ofjax.stages.Lowered
to get theoutput information (like tree structure, shape and dtype).
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
withjax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
.v0.4.29
Compare Source
Changes
supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
plugin jaxlib (e.g.
pip install jax[cuda12]
).jax.experimental.export
API. It is not possible anymore to usefrom jax.experimental.export import export
, and instead you should usefrom jax.experimental import export
.The removed functionality has been deprecated since 0.4.24.
is_leaf
argument to {func}jax.tree.all
& {func}jax.tree_util.tree_all
.Deprecations
jax.sharding.XLACompatibleSharding
is deprecated. Please usejax.sharding.Sharding
.jax.experimental.Exported.in_shardings
has been renamed asjax.experimental.Exported.in_shardings_hlo
. Same forout_shardings
.The old names will be removed after 3 months.
jax.core
:non_negative_dim
,DimSize
,Shape
jax.lax
:tie_in
jax.nn
:normalize
jax.interpreters.xla
:backend_specific_translations
,translations
,register_translation
,xla_destructure
,TranslationRule
,TranslationContext
,XlaOp
.tol
argument of {func}jax.numpy.linalg.matrix_rank
is beingdeprecated and will soon be removed. Use
rtol
instead.rcond
argument of {func}jax.numpy.linalg.pinv
is beingdeprecated and will soon be removed. Use
rtol
instead.jax.config
submodule has been removed. To configure JAXuse
import jax
and then reference the config object viajax.config
.jax.random
APIs no longer accept batched keys, where previouslysome did unintentionally. Going forward, we recommend explicit use of
{func}
jax.vmap
in such cases.jax.scipy.special.beta
, thex
andy
parameters have beenrenamed to
a
andb
for consistency with otherbeta
APIs.New Functionality
jax.experimental.Exported.in_shardings_jax
to constructshardings that can be used with the JAX APIs from the HloShardings
that are stored in the
Exported
objects.v0.4.28
Compare Source
Bug fixes
make_jaxpr
that was breaking Equinox (#21116).Deprecations & removals
kind
argument to {func}jax.numpy.sort
and {func}jax.numpy.argsort
is now removed. Use
stable=True
orstable=False
instead.get_compute_capability
from thejax.experimental.pallas.gpu
module. Use the
compute_capability
attribute of a GPU device, returnedby {func}
jax.devices
or {func}jax.local_devices
, instead.newshape
argument to {func}jax.numpy.reshape
is being deprecatedand will soon be removed. Use
shape
instead.Changes
v0.4.27
Compare Source
New Functionality
jax.numpy.unstack
and {func}jax.numpy.cumulative_sum
,following their addition in the array API 2023 standard, soon to be
adopted by NumPy.
jax_cpu_collectives_implementation
to select theimplementation of cross-process collective operations used by the CPU backend.
Choices available are
'none'
(default),'gloo'
and'mpi'
(requires jaxlib 0.4.26).If set to
'none'
, cross-process collective operations are disabled.Changes
jax.pure_callback
, {func}jax.experimental.io_callback
and {func}
jax.debug.callback
now use {class}jax.Array
insteadof {class}
np.ndarray
. You can recover the old behavior by transformingthe arguments via
jax.tree.map(np.asarray, args)
before passing themto the callback.
complex_arr.astype(bool)
now follows the same semantics as NumPy, returningFalse where
complex_arr
is equal to0 + 0j
, and True otherwise.core.Token
now is a non-trivial class which wraps ajax.Array
. It couldbe created and threaded in and out of computations to build up dependency.
The singleton object
core.token
has been removed, users now should createand use fresh
core.Token
objects instead.by default. This choice can improve runtime memory usage at a compile-time
cost. Prior behavior, which produces a kernel call, can be recovered with
jax.config.update('jax_threefry_gpu_kernel_lowering', True)
. If the newdefault causes issues, please file a bug. Otherwise, we intend to remove
this flag in a future release.
Deprecations & Removals
lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA
environment variable no longer has any effect.jax.numpy.clip
has a new argument signature:a
,a_min
, anda_max
are deprecated in favor ofx
(positional only),min
, andmax
({jax-issue}20550
).device()
method of JAX arrays has been removed, after being deprecatedsince JAX v0.4.21. Use
arr.devices()
instead.initial
argument to {func}jax.nn.softmax
and {func}jax.nn.log_softmax
is deprecated; empty inputs to softmax are now supported without setting this.
jax.jit
, passing invalidstatic_argnums
orstatic_argnames
now leads to an error rather than a warning.
jax.numpy.hypot
function now issues a deprecation warning whenpassing complex-valued inputs to it. This will raise an error when the
deprecation is completed.
jax.numpy.nonzero
, {func}jax.numpy.where
, andrelated functions now raise an error, following a similar change in NumPy.
jax_cpu_enable_gloo_collectives
is deprecated.Use
jax.config.update('jax_cpu_collectives_implementation', 'gloo')
instead.jax.Array.device_buffer
andjax.Array.device_buffers
methods havebeen removed after being deprecated in JAX v0.4.22. Instead use
{attr}
jax.Array.addressable_shards
and {meth}jax.Array.addressable_data
.condition
,x
, andy
parameters ofjax.numpy.where
are nowpositional-only, following deprecation of the keywords in JAX v0.4.21.
jax.lax.linalg
now must bespecified by keyword. Previously, this raised a DeprecationWarning.
jax.numpy
APIs,including {func}
~jax.numpy.apply_along_axis
,{func}
~jax.numpy.apply_over_axes
, {func}~jax.numpy.inner
,{func}
~jax.numpy.outer
, {func}~jax.numpy.cross
,{func}
~jax.numpy.kron
, and {func}~jax.numpy.lexsort
.Bug fixes
jax.numpy.astype
will now always return a copy whencopy=True
.Previously, no copy would be made when the output array would have the same
dtype as the input array. This may result in some increased memory usage.
The default value is set to
copy=False
to preserve backwards compatibility.v0.4.26
Compare Source
New Functionality
jax.numpy.trapezoid
, following the addition of this function inNumPy 2.0.
Changes
jax.numpy.geomspace
now chooses the logarithmic spiralbranch consistent with that of NumPy 2.0.
lax.rng_bit_generator
, and in turn the'rbg'
and
'unsafe_rbg'
PRNG implementations, underjax.vmap
haschanged so that
mapping over keys results in random generation only from the first
key in the batch.
jax.random.key
for construction of PRNG key arraysrather than
jax.random.PRNGKey
.Deprecations & Removals
jax.tree_map
is deprecated; usejax.tree.map
instead, or for backwardcompatibility with older JAX versions, use {func}
jax.tree_util.tree_map
.jax.clear_backends
is deprecated as it does not necessarily do whatits name suggests and can lead to unexpected consequences, e.g., it will not
destroy existing backends and release corresponding owned resources. Use
{func}
jax.clear_caches
if you only want to clean up compilation caches.For backward compatibility or you really need to switch/reinitialize the
default backend, use {func}
jax.extend.backend.clear_backends
.jax.experimental.maps
module andjax.experimental.maps.xmap
aredeprecated. Use
jax.experimental.shard_map
orjax.vmap
with thespmd_axis_name
argument for expressing SPMD device-parallel computations.jax.experimental.host_callback
module is deprecated.Use instead the new JAX external callbacks.
Added
JAX_HOST_CALLBACK_LEGACY
flag to assist in the transition to thenew callbacks. See {jax-issue}
#20385
for a discussion.jax.numpy.array_equal
and {func}jax.numpy.array_equiv
that cannot be converted to a JAX array now results in an exception.
jax_parallel_functions_output_gda
has been removed.This flag was long deprecated and did nothing; its use was a no-op.
jax.interpreters.ad.config
andjax.interpreters.ad.source_info_util
have now been removed. Usejax.config
and
jax.extend.source_info_util
instead.has been supported since October 27th, 2023 and has become the default
since February 1, 2024.
See a description of the versions.
This change could break clients that set a specific
JAX serialization version lower than 9.
v0.4.25
Compare Source
New Features
Interface
import support (requires jaxlib 0.4.24).
x[True]
orx[False]
.jax.tree
module, with a more convenient interface for referencing functionsin {mod}
jax.tree_util
.jax.tree.transpose
(i.e. {func}jax.tree_util.tree_transpose
) now acceptsinner_treedef=None
, in which case the inner treedef will be automatically inferred.Changes
kernels. You can revert to the old behavior by setting the
JAX_TRITON_COMPILE_VIA_XLA
environment variable to"0"
.jax.interpreters.xla
that were removed in v0.4.24have been re-added in v0.4.25, including
backend_specific_translations
,translations
,register_translation
,xla_destructure
,TranslationRule
,TranslationContext
, andXLAOp
. These are still considered deprecated, andwill be removed again in the future when better replacements are available.
Refer to {jax-issue}
#19816
for discussion.Deprecations & Removals
jax.numpy.linalg.solve
now shows a deprecation warning for batched 1Dsolves with
b.ndim > 1
. In the future these will be treated as batched 2Dsolves.
of the size of the array. Previously a deprecation warning was raised in the case of
non-scalar arrays of size 1. This follows a similar deprecation in NumPy.
following a standard 3 months deprecation cycle (see {ref}
api-compatibility
).These include
jax.config.config
object anddefine_*_state
andDEFINE_*
methods of {data}jax.config
.jax.config
submodule viaimport jax.config
is deprecated.To configure JAX use
import jax
and then reference the config objectvia
jax.config
.v0.4.24
Compare Source
Changes
If your primitive wraps custom_partitioning or JAX callbacks in the lowering
rule i.e. function passed to
rule
parameter ofmlir.register_lowering
then add yourprimitive to
jax._src.dispatch.prim_requires_devices_during_lowering
set.This is needed because custom_partitioning and JAX callbacks need physical
devices to create
Sharding
s during lowering.This is a temporary state until we can create
Sharding
s without physicaldevices.
jax.numpy.argsort
and {func}jax.numpy.sort
now support thestable
and
descending
arguments.{mod}
jax.experimental.jax2tf
and {mod}jax.experimental.export
):#19227
)This makes shape polymorphism more expressive, and gives a way to workaround
limitations in the reasoning about inequalities.
See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
#19235
) we nowconsider dimension variables from different scopes to be different, even
if they have the same name. Symbolic expressions from different scopes
cannot interact, e.g., in arithmetic operations.
Scopes are introduced by {func}
jax.experimental.jax2tf.convert
,{func}
jax.experimental.export.symbolic_shape
, {func}jax.experimental.export.symbolic_args_specs
.The scope of a symbolic expression
e
can be read withe.scope
and passedinto the above functions to direct them to construct symbolic expressions in
a given scope.
See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
to be equal if the normalized form of their difference reduces to 0
({jax-issue}
#19231
; note that this may result in user-visible behaviorchanges)
({jax-issue}
#19235
).core.non_negative_dim
API (introduced recently)was deprecated and
core.max_dim
andcore.min_dim
were introduced({jax-issue}
#18953
) to expressmax
andmin
for symbolic dimensions.You can use
core.max_dim(d, 0)
instead ofcore.non_negative_dim(d)
.shape_poly.is_poly_dim
is deprecated in favor ofexport.is_symbolic_dim
({jax-issue}
#19282
).export.args_specs
is deprecated in favor ofexport.symbolic_args_specs ({jax-issue}
#19283`).shape_poly.PolyShape
andjax2tf.PolyShape
are deprecated, usestrings for polymorphic shapes specifications ({jax-issue}
#19284
).for {mod}
jax.experimental.jax2tf
and {mod}jax.experimental.export
.See description of version numbers.
jax.experimental.export
. Instead offrom jax.experimental.export import export
you should use nowfrom jax.experimental import export
. The old way of importing willcontinue to work for a deprecation period of 3 months.
jax.scipy.stats.sem
.jax.numpy.unique
withreturn_inverse = True
returns inverse indicesreshaped to the dimension of the input, following a similar change to
{func}
numpy.unique
in NumPy 2.0.jax.numpy.sign
now returnsx / abs(x)
for nonzero complex inputs. This isconsistent with the behavior of {func}
numpy.sign
in NumPy version 2.0.jax.scipy.special.logsumexp
withreturn_sign=True
now uses the NumPy 2.0convention for the complex sign,
x / abs(x)
. This is consistent with the behaviorof {func}
scipy.special.logsumexp
in SciPy v1.13.Previously bool values could not be imported and were exported as integers.
Deprecations & Removals
standard 3+ month deprecation cycle (see {ref}
api-compatibility
).This includes:
jax.core
:TracerArrayConversionError
,TracerIntegerConversionError
,UnexpectedTracerError
,as_hashable_function
,collections
,dtypes
,lu
,map
,namedtuple
,partial
,pp
,ref
,safe_zip
,safe_map
,source_info_util
,total_ordering
,traceback_util
,tuple_delete
,tuple_insert
, andzip
.jax.lax
:dtypes
,itertools
,naryop
,naryop_dtype_rule
,standard_abstract_eval
,standard_naryop
,standard_primitive
,standard_unop
,unop
, andunop_dtype_rule
.jax.linear_util
submodule and all its contents.jax.prng
submodule and all its contents.jax.random
:PRNGKeyArray
,KeyArray
,default_prng_impl
,threefry_2x32
,threefry2x32_key
,threefry2x32_p
,rbg_key
, andunsafe_rbg_key
.jax.tree_util
:register_keypaths
,AttributeKeyPathEntry
, andGetItemKeyPathEntry
.jax.interpreters.xla
:backend_specific_translations
,translations
,register_translation
,xla_destructure
,TranslationRule
,TranslationContext
,axis_groups
,ShapedArray
,ConcreteArray
,AxisEnv
,backend_compile
,and
XLAOp
.jax.numpy
:NINF
,NZERO
,PZERO
,row_stack
,issubsctype
,trapz
, andin1d
.jax.scipy.linalg
:tril
andtriu
.PRNGKeyArray.unsafe_raw_array
has beenremoved. Use {func}
jax.random.key_data
instead.bool(empty_array)
now raises an error rather than returningFalse
. Thispreviously raised a deprecation warning, and follows a similar change in NumPy.
the mhlo dialect, in favor of stablehlo. APIs that refer to "mhlo" will be
removed in the future. Use the "stablehlo" dialect instead.
jax.random
: passing batched keys directly to random number generation functions,such as {func}
~jax.random.bits
, {func}~jax.random.gamma
, and others, is deprecatedand will emit a
FutureWarning
. Usejax.vmap
for explicit batching.jax.lax.tie_in
is deprecated: it has been a no-op since JAX v0.2.0.Configuration
📅 Schedule: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).
🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.
♻ Rebasing: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.
🔕 Ignore: Close this PR and you won't be reminded about this update again.
This PR was generated by Mend Renovate. View the repository job log.