Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(deps): update dependency jax to v0.5.0 #589

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

renovate[bot]
Copy link
Contributor

@renovate renovate bot commented Jul 14, 2024

This PR contains the following updates:

Package Change Age Adoption Passing Confidence
jax 0.4.23 -> 0.5.0 age adoption passing confidence

Release Notes

jax-ml/jax (jax)

v0.5.0

Compare Source

As of this release, JAX now uses
effort-based versioning.
Since this release makes a breaking change to PRNG key semantics that
may require users to update their code, we are bumping the "meso" version of JAX
to signify this.

  • Breaking changes

    • Enable jax_threefry_partitionable by default (see
      the update note).

    • This release drops support for Mac x86 wheels. Mac ARM of course remains
      supported. For a recent discussihttps://github.com/jax-ml/jax/discussions/22936iscussions/22936.

      Two key factors motivated this decision:

      • The Mac x86 build (only) has a number of test failures and crashes. We
        would prefer to ship no release than a broken release.
      • Mac x86 hardware is end-of-life and cannot be easily obtained for
        developers at this point. So it is difficult for us to fix this kind of
        problem even if we wanted to.

      We are open to readding support for Mac x86 if the community is willing
      to help support that platform: in particular, we would need the JAX test
      suite to pass cleanly on Mac x86 before we could ship releases again.

  • Changes:

    • The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
      supported version until June 2025.
    • The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum
      supported version until June 2025.
    • {func}jax.numpy.einsum now defaults to optimize='auto' rather than
      optimize='optimal'. This avoids exponentially-scaling trace-time in
      the case of many arguments ({jax-issue}#25214).
    • {func}jax.numpy.linalg.solve no longer supports batched 1D arguments
      on the right hand side. To recover the previous behavior in these cases,
      use solve(a, b[..., None]).squeeze(-1).
  • New Features

    • {func}jax.numpy.fft.fftn, {func}jax.numpy.fft.rfftn,
      {func}jax.numpy.fft.ifftn, and {func}jax.numpy.fft.irfftn now support
      transforms in more than 3 dimensions, which was previously the limit. See
      {jax-issue}#25606 for more details.
    • Support added for user defined state in the FFI via the new
      {func}jax.ffi.register_ffi_type_id function.
    • The AOT lowering .as_text() method now supports the debug_info option
      to include debugging information, e.g., source location, in the output.
  • Deprecations

    • From {mod}jax.interpreters.xla, abstractify and pytype_aval_mappings
      are now deprecated, having been replaced by symbols of the same name
      in {mod}jax.core.
    • {func}jax.scipy.special.lpmn and {func}jax.scipy.special.lpmn_values
      are deprecated, following their deprecation in SciPy v1.15.0. There are
      no plans to replace these deprecated functions with new APIs.
    • The {mod}jax.extend.ffi submodule was moved to {mod}jax.ffi, and the
      previous import path is deprecated.
  • Deletions

    • jax_enable_memories flag has been deleted and the behavior of that flag
      is on by default.
    • From jax.lib.xla_client, the previously-deprecated Device and
      XlaRuntimeError symbols have been removed; instead use jax.Device
      and jax.errors.JaxRuntimeError respectively.
    • The jax.experimental.array_api module has been removed after being
      deprecated in JAX v0.4.32. Since that release, {mod}jax.numpy supports
      the array API directly.

v0.4.38

Compare Source

  • Changes:

    • jax.tree.flatten_with_path and jax.tree.map_with_path are added
      as shortcuts of the corresponding tree_util functions.
  • Deprecations

    • a number of APIs in the internal jax.core namespace have been deprecated.
      Most were no-ops, were little-used, or can be replaced by APIs of the same
      name in {mod}jax.extend.core; see the documentation for {mod}jax.extend
      for information on the compatibility guarantees of these semi-public extensions.
    • Several previously-deprecated APIs have been removed, including:
      • from {mod}jax.core: check_eqn, check_type, check_valid_jaxtype, and
        non_negative_dim.
      • from {mod}jax.lib.xla_bridge: xla_client and default_backend.
      • from {mod}jax.lib.xla_client: _xla and bfloat16.
      • from {mod}jax.numpy: round_.
  • New Features

    • {func}jax.export.export can be used for device-polymorphic export with
      shardings constructed with {func}jax.sharding.AbstractMesh.
      See the jax.export documentation.
    • Added {func}jax.lax.split. This is a primitive version of
      {func}jax.numpy.split, added because it yields a more compact
      transpose during automatic differentiation.

v0.4.37

Compare Source

This is a patch release of jax 0.4.36. Only "jax" was released at this version.

  • Bug fixes
    • Fixed a bug where jit would error if an argument was named f (#​25329).
    • Fix a bug that will throw index out of range error in
      {func}jax.lax.while_loop if the user register pytree node class with
      different aux data for the flatten and flatten_with_path.
    • Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.

v0.4.36

Compare Source

  • Breaking Changes

    • This release lands "stackless", an internal change to JAX's tracing
      machinery. We made trace dispatch purely a function of context rather than a
      function of both context and data. This let us delete a lot of machinery for
      managing data-dependent tracing: levels, sublevels, post_process_call,
      new_base_main, custom_bind, and so on. The change should only affect
      users that use JAX internals.

      If you do use JAX internals then you may need to
      update your code (see
      jax-ml/jax@c36e1f7
      for clues about how to do this). There might also be version skew
      issues with JAX libraries that do this. If you find this change breaks your
      non-JAX-internals-using code then try the
      config.jax_data_dependent_tracing_fallback flag as a workaround, and if
      you need help updating your code then please file a bug.

    • {func}jax.experimental.jax2tf.convert with native_serialization=False
      or with enable_xla=False have been deprecated since July 2024, with
      JAX version 0.4.31. Now we removed support for these use cases. jax2tf
      with native serialization will still be supported.

    • In jax.interpreters.xla, the xb, xc, and xe symbols have been removed
      after being deprecated in JAX v0.4.31. Instead use xb = jax.lib.xla_bridge,
      xc = jax.lib.xla_client, and xe = jax.lib.xla_extension.

    • The deprecated module jax.experimental.export has been removed. It was replaced
      by {mod}jax.export in JAX v0.4.30. See the migration guide
      for information on migrating to the new API.

    • The initial argument to {func}jax.nn.softmax and {func}jax.nn.log_softmax
      has been removed, after being deprecated in v0.4.27.

    • Calling np.asarray on typed PRNG keys (i.e. keys produced by :func:jax.random.key)
      now raises an error. Previously, this returned a scalar object array.

    • The following deprecated methods and functions in {mod}jax.export have
      been removed:

      • jax.export.DisabledSafetyCheck.shape_assertions: it had no effect
        already.
      • jax.export.Exported.lowering_platforms: use platforms.
      • jax.export.Exported.mlir_module_serialization_version:
        use calling_convention_version.
      • jax.export.Exported.uses_shape_polymorphism:
        use uses_global_constants.
      • the lowering_platforms kwarg for {func}jax.export.export: use
        platforms instead.
    • The kwargs symbolic_scope and symbolic_constraints from
      {func}jax.export.symbolic_args_specs have been removed. They were
      deprecated in June 2024. Use scope and constraints instead.

    • Hashing of tracers, which has been deprecated since version 0.4.30, now
      results in a TypeError.

    • Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
      replaces previous build.py usage. Run python build/build.py --help for
      more details. Brief overview of the new subcommand options:

      • build: Builds JAX wheel packages. For e.g., python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
      • requirements_update: Updates requirements_lock.txt files.
    • {func}jax.scipy.linalg.toeplitz now does implicit batching on multi-dimensional
      inputs. To recover the previous behavior, you can call {func}jax.numpy.ravel
      on the function inputs.

    • {func}jax.scipy.special.gamma and {func}jax.scipy.special.gammasgn now
      return NaN for negative integer inputs, to match the behavior of Scihttps://github.com/scipy/scipy/pull/21827scipy/pull/21827.

    • jax.clear_backends was removed after being deprecated in v0.4.26.

    • We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
      call that we guarantee export stability. This is because this custom call
      relies on Triton IR, which is not guaranteed to be stable. If you need
      to export code that uses this custom call, you can use the disabled_checks
      parameter. See more details in the documentation.

  • New Features

    • {func}jax.jit got a new compiler_options: dict[str, Any] argument, for
      passing compilation options to XLA. For the moment it's undocumented and
      may be in flux.
    • {func}jax.tree_util.register_dataclass now allows metadata fields to be
      declared inline via {func}dataclasses.field. See the function documentation
      for examples.
    • Added {func}jax.numpy.put_along_axis.
    • {func}jax.lax.linalg.eig and the related jax.numpy functions
      ({func}jax.numpy.linalg.eig and {func}jax.numpy.linalg.eigvals) are now
      supported on GPU. See {jax-issue}#24663 for more details.
    • Added two new configuration flags, jax_exec_time_optimization_effort and jax_memory_fitting_effort, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.
  • Bug fixes

    • Fixed a bug where the GPU implementations of LU and QR decomposition would
      result in an indexing overflow for batch sizes close to int32 max. See
      {jax-issue}#24843 for more details.
  • Deprecations

    • jax.lib.xla_extension.ArrayImpl and jax.lib.xla_client.ArrayImpl are deprecated;
      use jax.Array instead.
    • jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError
      instead.

v0.4.35

Compare Source

  • Breaking Changes

    • {func}jax.numpy.isscalar now returns True for any array-like object with
      zero dimensions. Previously it only returned True for zero-dimensional
      array-like objects with a weak dtype.
    • jax.experimental.host_callback has been deprecated since March 2024, with
      JAX version 0.4.26. Now we removed it.
      See {jax-issue}#20385 for a discussion of alternatives.
  • Changes:

    • jax.lax.FftType was introduced as a public name for the enum of FFT
      operations. The semi-public API jax.lib.xla_client.FftType has been
      deprecated.
    • TPU: JAX now installs TPU support from the libtpu package rather than
      libtpu-nightly. For the next few releases JAX will pin an empty version of
      libtpu-nightly as well as libtpu to ease the transition; that dependency
      will be removed in Q1 2025.
  • Deprecations:

    • The semi-public API jax.lib.xla_client.PaddingType has been deprecated.
      No JAX APIs consume this type, so there is no replacement.
    • The default behavior of {func}jax.pure_callback and
      {func}jax.extend.ffi.ffi_call under vmap has been deprecated and so has
      the vectorized parameter to those functions. The vmap_method parameter
      should be used instead for better defined behavior. See the discussion in
      {jax-issue}#23881 for more details.
    • The semi-public API jax.lib.xla_client.register_custom_call_target has
      been deprecated. Use the JAX FFI instead.
    • The semi-public APIs jax.lib.xla_client.dtype_to_etype,
      jax.lib.xla_client.ops,
      jax.lib.xla_client.shape_from_pyval, jax.lib.xla_client.PrimitiveType,
      jax.lib.xla_client.Shape, jax.lib.xla_client.XlaBuilder, and
      jax.lib.xla_client.XlaComputation have been deprecated. Use StableHLO
      instead.

v0.4.34

Compare Source

  • New Functionality

    • This release includes wheels for Python 3.13. Free-threading mode is not yet
      supported.
    • jax.errors.JaxRuntimeError has been added as a public alias for the
      formerly private XlaRuntimeError type.
  • Breaking changes

    • jax_pmap_no_rank_reduction flag is set to True by default.
      • array[0] on a pmap result now introduces a reshape (use array[0:1]
        instead).
      • The per-shard shape (accessable via jax_array.addressable_shards or
        jax_array.addressable_data(0)) now has a leading (1, ...). Update code
        that directly accesses shards accordingly. The rank of the per-shard-shape
        now matches that of the global shape which is the same behavior as jit.
        This avoids costly reshapes when passing results from pmap into jit.
    • jax.experimental.host_callback has been deprecated since March 2024, with
      JAX version 0.4.26. Now we set the default value of the
      --jax_host_callback_legacy configuration value to True, which means that
      if your code uses jax.experimental.host_callback APIs, those API calls
      will be implemented in terms of the new jax.experimental.io_callback API.
      If this breaks your code, for a very limited time, you can set the
      --jax_host_callback_legacy to True. Soon we will remove that
      configuration option, so you should instead transition to using the
      new JAX callback APIs. See {jax-issue}#20385 for a discussion.
  • Deprecations

    • In {func}jax.numpy.trim_zeros, non-arraylike arguments or arraylike
      arguments with ndim != 1 are now deprecated, and in the future will result
      in an error.
    • Internal pretty-printing tools jax.core.pp_* have been removed, after
      being deprecated in JAX v0.4.30.
    • jax.lib.xla_client.Device is deprecated; use jax.Device instead.
    • jax.lib.xla_client.XlaRuntimeError has been deprecated. Use
      jax.errors.JaxRuntimeError instead.
  • Deletion:

    • jax.xla_computation is deleted. It's been 3 months since it's deprecation
      in 0.4.30 JAX release.
      Please use the AOT APIs to get the same functionality as jax.xla_computation.
      • jax.xla_computation(fn)(*args, **kwargs) can be replaced with
        jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo').
      • You can also use .out_info property of jax.stages.Lowered to get the
        output information (like tree structure, shape and dtype).
      • For cross-backend lowering, you can replace
        jax.xla_computation(fn, backend='tpu')(*args, **kwargs) with
        jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo').
    • {class}jax.ShapeDtypeStruct no longer accepts the named_shape argument.
      The argument was only used by xmap which was removed in 0.4.31.
    • jax.tree.map(f, None, non-None), which previously emitted a
      DeprecationWarning, now raises an error in a future version of jax. None
      is only a tree-prefix of itself. To preserve the current behavior, you can
      ask jax.tree.map to treat None as a leaf value by writing:
      jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None).
    • jax.sharding.XLACompatibleSharding has been removed. Please use
      jax.sharding.Sharding.
  • Bug fixes

    • Fixed a bug where {func}jax.numpy.cumsum would produce incorrect outputs
      if a non-boolean input was provided and dtype=bool was specified.
    • Edit implementation of {func}jax.numpy.ldexp to get correct gradient.

v0.4.33

Compare Source

This is a patch release on top of jax 0.4.32, that fixes two bugs found in that
release.

A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of libtpu.

This release fixes an inaccurate result for F64 tanh on CPU (#​23590).

v0.4.32

Compare Source

Note: This release was yanked from PyPi because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

  • New Functionality

    • Added {func}jax.extend.ffi.ffi_call and {func}jax.extend.ffi.ffi_lowering
      to support the use of the new {ref}ffi-tutorial to interface with custom
      C++ and CUDA code from JAX.
  • Changes

    • jax_enable_memories flag is set to True by default.
    • {mod}jax.numpy now supports v2023.12 of the Python Array API Standard.
      See {ref}python-array-api for more information.
    • Computations on the CPU backend may now be dispatched asynchronously in
      more cases. Previously non-parallel computations were always dispatched
      synchronously. You can recover the old behavior by setting
      jax.config.update('jax_cpu_enable_async_dispatch', False).
    • Added new {func}jax.process_indices function to replace the
      jax.host_ids() function that was deprecated in JAX v0.2.13.
    • To align with the behavior of numpy.fabs, jax.numpy.fabs has been
      modified to no longer support complex dtypes.
    • jax.tree_util.register_dataclass now checks that data_fields
      and meta_fields includes all dataclass fields with init=True
      and only them, if nodetype is a dataclass.
    • Several {mod}jax.numpy functions now have full {class}~jax.numpy.ufunc
      interfaces, including {obj}~jax.numpy.add, {obj}~jax.numpy.multiply,
      {obj}~jax.numpy.bitwise_and, {obj}~jax.numpy.bitwise_or,
      {obj}~jax.numpy.bitwise_xor, {obj}~jax.numpy.logical_and,
      {obj}~jax.numpy.logical_and, and {obj}~jax.numpy.logical_and.
      In future releases we plan to expand these to other ufuncs.
    • Added {func}jax.lax.optimization_barrier, which allows users to prevent
      compiler optimizations such as common-subexpression elimination and to
      control scheduling.
  • Breaking changes

    • The MHLO MLIR dialect (jax.extend.mlir.mhlo) has been removed. Use the
      stablehlo dialect instead.
  • Deprecations

    • Complex inputs to {func}jax.numpy.clip and {func}jax.numpy.hypot are
      no longer allowed, after being deprecated since JAX v0.4.27.
    • Deprecated the following APIs:
      • jax.lib.xla_bridge.xla_client: use {mod}jax.lib.xla_client directly.
      • jax.lib.xla_bridge.get_backend: use {func}jax.extend.backend.get_backend.
      • jax.lib.xla_bridge.default_backend: use {func}jax.extend.backend.default_backend.
    • The jax.experimental.array_api module is deprecated, and importing it is no
      longer required to use the Array API. jax.numpy supports the array API
      directly; see {ref}python-array-api for more information.
    • The internal utilities jax.core.check_eqn, jax.core.check_type, and
      jax.core.check_valid_jaxtype are now deprecated, and will be removed in
      the future.
    • jax.numpy.round_ has been deprecated, following removal of the corresponding
      API in NumPy 2.0. Use {func}jax.numpy.round instead.
    • Passing a DLPack capsule to {func}jax.dlpack.from_dlpack is deprecated.
      The argument to {func}jax.dlpack.from_dlpack should be an array from
      another framework that implements the __dlpack__ protocol.

v0.4.31

Compare Source

  • Deletion

    • xmap has been deleted. Please use {func}shard_map as the replacement.
  • Changes

    • The minimum CuDNN version is v9.1. This was true in previous releases also,
      but we now declare this version constraint formally.
    • The minimum Python version is now 3.10. 3.10 will remain the minimum
      supported version until July 2025.
    • The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum
      supported version until December 2024.
    • The minimum SciPy version is now 1.10. SciPy 1.10 will remain the minimum
      supported version until January 2025.
    • {func}jax.numpy.ceil, {func}jax.numpy.floor and {func}jax.numpy.trunc now return the output
      of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
    • libdevice.10.bc is no longer bundled with CUDA wheels. It must be
      installed either as a part of local CUDA installation, or via NVIDIA's CUDA
      pip wheels.
    • {class}jax.experimental.pallas.BlockSpec now expects block_shape to
      be passed before index_map. The old argument order is deprecated and
      will be removed in a future release.
    • Updated the repr of gpu devices to be more consistent
      with TPUs/CPUs. For example, cuda(id=0) will now be CudaDevice(id=0).
    • Added the device property and to_device method to {class}jax.Array, as
      part of JAX's Array API support.
  • Deprecations

    • Removed a number of previously-deprecated internal APIs related to
      polymorphic shapes. From {mod}jax.core: removed canonicalize_shape,
      dimension_as_value, definitely_equal, and symbolic_equal_dim.
    • HLO lowering rules should no longer wrap singleton ir.Values in tuples.
      Instead, return singleton ir.Values unwrapped. Support for wrapped values
      will be removed in a future version of JAX.
    • {func}jax.experimental.jax2tf.convert with native_serialization=False
      or enable_xla=False is now deprecated and this support will be removed in
      a future version.
      Native serialization has been the default since JAX 0.4.16 (September 2023).
    • The previously-deprecated function jax.random.shuffle has been removed;
      instead use jax.random.permutation with independent=True.

v0.4.30

Compare Source

  • Changes

    • JAX supports ml_dtypes >= 0.2. In 0.4.29 release, the ml_dtypes version was
      bumped to 0.4.0 but this has been rolled back in this release to give users
      of both TensorFlow and JAX more time to migrate to a newer TensorFlow
      release.
    • jax.experimental.mesh_utils can now create an efficient mesh for TPU v5e.
    • jax now depends on jaxlib directly. This change was enabled by the CUDA
      plugin switch: there are no longer multiple jaxlib variants. You can install
      a CPU-only jax with pip install jax, no extras required.
    • Added an API for exporting and serializing JAX functions. This used
      to exist in jax.experimental.export (which is being deprecated),
      and will now live in jax.export.
      See the documentation.
  • Deprecations

    • Internal pretty-printing tools jax.core.pp_* are deprecated, and will be removed
      in a future release.
    • Hashing of tracers is deprecated, and will lead to a TypeError in a future JAX
      release. This previously was the case, but there was an inadvertent regression in
      the last several JAX releases.
    • jax.experimental.export is deprecated. Use {mod}jax.export instead.
      See the migration guide.
    • Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays
      x and y, x.astype(y) will raise a warning. To silence it use x.astype(y.dtype).
    • jax.xla_computation is deprecated and will be removed in a future release.
      Please use the AOT APIs to get the same functionality as jax.xla_computation.
      • jax.xla_computation(fn)(*args, **kwargs) can be replaced with
        jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo').
      • You can also use .out_info property of jax.stages.Lowered to get the
        output information (like tree structure, shape and dtype).
      • For cross-backend lowering, you can replace
        jax.xla_computation(fn, backend='tpu')(*args, **kwargs) with
        jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo').

v0.4.29

Compare Source

  • Changes

    • We anticipate that this will be the last release of JAX and jaxlib
      supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
      plugin jaxlib (e.g. pip install jax[cuda12]).
    • JAX now requires ml_dtypes version 0.4.0 or newer.
    • Removed backwards-compatibility support for old usage of the
      jax.experimental.export API. It is not possible anymore to use
      from jax.experimental.export import export, and instead you should use
      from jax.experimental import export.
      The removed functionality has been deprecated since 0.4.24.
    • Added is_leaf argument to {func}jax.tree.all & {func}jax.tree_util.tree_all.
  • Deprecations

    • jax.sharding.XLACompatibleSharding is deprecated. Please use
      jax.sharding.Sharding.
    • jax.experimental.Exported.in_shardings has been renamed as
      jax.experimental.Exported.in_shardings_hlo. Same for out_shardings.
      The old names will be removed after 3 months.
    • Removed a number of previously-deprecated APIs:
      • from {mod}jax.core: non_negative_dim, DimSize, Shape
      • from {mod}jax.lax: tie_in
      • from {mod}jax.nn: normalize
      • from {mod}jax.interpreters.xla: backend_specific_translations,
        translations, register_translation, xla_destructure,
        TranslationRule, TranslationContext, XlaOp.
    • The tol argument of {func}jax.numpy.linalg.matrix_rank is being
      deprecated and will soon be removed. Use rtol instead.
    • The rcond argument of {func}jax.numpy.linalg.pinv is being
      deprecated and will soon be removed. Use rtol instead.
    • The deprecated jax.config submodule has been removed. To configure JAX
      use import jax and then reference the config object via jax.config.
    • {mod}jax.random APIs no longer accept batched keys, where previously
      some did unintentionally. Going forward, we recommend explicit use of
      {func}jax.vmap in such cases.
    • In {func}jax.scipy.special.beta, the x and y parameters have been
      renamed to a and b for consistency with other beta APIs.
  • New Functionality

    • Added {func}jax.experimental.Exported.in_shardings_jax to construct
      shardings that can be used with the JAX APIs from the HloShardings
      that are stored in the Exported objects.

v0.4.28

Compare Source

  • Bug fixes

    • Reverted a change to make_jaxpr that was breaking Equinox (#​21116).
  • Deprecations & removals

    • The kind argument to {func}jax.numpy.sort and {func}jax.numpy.argsort
      is now removed. Use stable=True or stable=False instead.
    • Removed get_compute_capability from the jax.experimental.pallas.gpu
      module. Use the compute_capability attribute of a GPU device, returned
      by {func}jax.devices or {func}jax.local_devices, instead.
    • The newshape argument to {func}jax.numpy.reshapeis being deprecated
      and will soon be removed. Use shape instead.
  • Changes

    • The minimum jaxlib version of this release is 0.4.27.

v0.4.27

Compare Source

  • New Functionality

    • Added {func}jax.numpy.unstack and {func}jax.numpy.cumulative_sum,
      following their addition in the array API 2023 standard, soon to be
      adopted by NumPy.
    • Added a new config option jax_cpu_collectives_implementation to select the
      implementation of cross-process collective operations used by the CPU backend.
      Choices available are 'none'(default), 'gloo' and 'mpi' (requires jaxlib 0.4.26).
      If set to 'none', cross-process collective operations are disabled.
  • Changes

    • {func}jax.pure_callback, {func}jax.experimental.io_callback
      and {func}jax.debug.callback now use {class}jax.Array instead
      of {class}np.ndarray. You can recover the old behavior by transforming
      the arguments via jax.tree.map(np.asarray, args) before passing them
      to the callback.
    • complex_arr.astype(bool) now follows the same semantics as NumPy, returning
      False where complex_arr is equal to 0 + 0j, and True otherwise.
    • core.Token now is a non-trivial class which wraps a jax.Array. It could
      be created and threaded in and out of computations to build up dependency.
      The singleton object core.token has been removed, users now should create
      and use fresh core.Token objects instead.
    • On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
      by default. This choice can improve runtime memory usage at a compile-time
      cost. Prior behavior, which produces a kernel call, can be recovered with
      jax.config.update('jax_threefry_gpu_kernel_lowering', True). If the new
      default causes issues, please file a bug. Otherwise, we intend to remove
      this flag in a future release.
  • Deprecations & Removals

    • Pallas now exclusively uses XLA for compiling kernels on GPU. The old
      lowering pass via Triton Python APIs has been removed and the
      JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.
    • {func}jax.numpy.clip has a new argument signature: a, a_min, and
      a_max are deprecated in favor of x (positional only), min, and
      max ({jax-issue}20550).
    • The device() method of JAX arrays has been removed, after being deprecated
      since JAX v0.4.21. Use arr.devices() instead.
    • The initial argument to {func}jax.nn.softmax and {func}jax.nn.log_softmax
      is deprecated; empty inputs to softmax are now supported without setting this.
    • In {func}jax.jit, passing invalid static_argnums or static_argnames
      now leads to an error rather than a warning.
    • The minimum jaxlib version is now 0.4.23.
    • The {func}jax.numpy.hypot function now issues a deprecation warning when
      passing complex-valued inputs to it. This will raise an error when the
      deprecation is completed.
    • Scalar arguments to {func}jax.numpy.nonzero, {func}jax.numpy.where, and
      related functions now raise an error, following a similar change in NumPy.
    • The config option jax_cpu_enable_gloo_collectives is deprecated.
      Use jax.config.update('jax_cpu_collectives_implementation', 'gloo') instead.
    • The jax.Array.device_buffer and jax.Array.device_buffers methods have
      been removed after being deprecated in JAX v0.4.22. Instead use
      {attr}jax.Array.addressable_shards and {meth}jax.Array.addressable_data.
    • The condition, x, and y parameters of jax.numpy.where are now
      positional-only, following deprecation of the keywords in JAX v0.4.21.
    • Non-array arguments to functions in {mod}jax.lax.linalg now must be
      specified by keyword. Previously, this raised a DeprecationWarning.
    • Array-like arguments are now required in several :func:jax.numpy APIs,
      including {func}~jax.numpy.apply_along_axis,
      {func}~jax.numpy.apply_over_axes, {func}~jax.numpy.inner,
      {func}~jax.numpy.outer, {func}~jax.numpy.cross,
      {func}~jax.numpy.kron, and {func}~jax.numpy.lexsort.
  • Bug fixes

    • {func}jax.numpy.astype will now always return a copy when copy=True.
      Previously, no copy would be made when the output array would have the same
      dtype as the input array. This may result in some increased memory usage.
      The default value is set to copy=False to preserve backwards compatibility.

v0.4.26

Compare Source

  • New Functionality

    • Added {func}jax.numpy.trapezoid, following the addition of this function in
      NumPy 2.0.
  • Changes

    • Complex-valued {func}jax.numpy.geomspace now chooses the logarithmic spiral
      branch consistent with that of NumPy 2.0.
    • The behavior of lax.rng_bit_generator, and in turn the 'rbg'
      and 'unsafe_rbg' PRNG implementations, under jax.vmap has
      changed
      so that
      mapping over keys results in random generation only from the first
      key in the batch.
    • Docs now use jax.random.key for construction of PRNG key arrays
      rather than jax.random.PRNGKey.
  • Deprecations & Removals

    • {func}jax.tree_map is deprecated; use jax.tree.map instead, or for backward
      compatibility with older JAX versions, use {func}jax.tree_util.tree_map.
    • {func}jax.clear_backends is deprecated as it does not necessarily do what
      its name suggests and can lead to unexpected consequences, e.g., it will not
      destroy existing backends and release corresponding owned resources. Use
      {func}jax.clear_caches if you only want to clean up compilation caches.
      For backward compatibility or you really need to switch/reinitialize the
      default backend, use {func}jax.extend.backend.clear_backends.
    • The jax.experimental.maps module and jax.experimental.maps.xmap are
      deprecated. Use jax.experimental.shard_map or jax.vmap with the
      spmd_axis_name argument for expressing SPMD device-parallel computations.
    • The jax.experimental.host_callback module is deprecated.
      Use instead the new JAX external callbacks.
      Added JAX_HOST_CALLBACK_LEGACY flag to assist in the transition to the
      new callbacks. See {jax-issue}#20385 for a discussion.
    • Passing arguments to {func}jax.numpy.array_equal and {func}jax.numpy.array_equiv
      that cannot be converted to a JAX array now results in an exception.
    • The deprecated flag jax_parallel_functions_output_gda has been removed.
      This flag was long deprecated and did nothing; its use was a no-op.
    • The previously-deprecated imports jax.interpreters.ad.config and
      jax.interpreters.ad.source_info_util have now been removed. Use jax.config
      and jax.extend.source_info_util instead.
    • JAX export does not support older serialization versions anymore. Version 9
      has been supported since October 27th, 2023 and has become the default
      since February 1, 2024.
      See a description of the versions.
      This change could break clients that set a specific
      JAX serialization version lower than 9.

v0.4.25

Compare Source

  • New Features

    • Added CUDA Array
      Interface

      import support (requires jaxlib 0.4.24).
    • JAX arrays now support NumPy-style scalar boolean indexing, e.g. x[True] or x[False].
    • Added {mod}jax.tree module, with a more convenient interface for referencing functions
      in {mod}jax.tree_util.
    • {func}jax.tree.transpose (i.e. {func}jax.tree_util.tree_transpose) now accepts
      inner_treedef=None, in which case the inner treedef will be automatically inferred.
  • Changes

    • Pallas now uses XLA instead of the Triton Python APIs to compile Triton
      kernels. You can revert to the old behavior by setting the
      JAX_TRITON_COMPILE_VIA_XLA environment variable to "0".
    • Several deprecated APIs in {mod}jax.interpreters.xla that were removed in v0.4.24
      have been re-added in v0.4.25, including backend_specific_translations,
      translations, register_translation, xla_destructure, TranslationRule,
      TranslationContext, and XLAOp. These are still considered deprecated, and
      will be removed again in the future when better replacements are available.
      Refer to {jax-issue}#19816 for discussion.
  • Deprecations & Removals

    • {func}jax.numpy.linalg.solve now shows a deprecation warning for batched 1D
      solves with b.ndim > 1. In the future these will be treated as batched 2D
      solves.
    • Conversion of a non-scalar array to a Python scalar now raises an error, regardless
      of the size of the array. Previously a deprecation warning was raised in the case of
      non-scalar arrays of size 1. This follows a similar deprecation in NumPy.
    • The previously deprecated configuration APIs have been removed
      following a standard 3 months deprecation cycle (see {ref}api-compatibility).
      These include
      • the jax.config.config object and
      • the define_*_state and DEFINE_* methods of {data}jax.config.
    • Importing the jax.config submodule via import jax.config is deprecated.
      To configure JAX use import jax and then reference the config object
      via jax.config.
    • The minimum jaxlib version is now 0.4.20.

v0.4.24

Compare Source

  • Changes

    • JAX lowering to StableHLO does not depend on physical devices anymore.
      If your primitive wraps custom_partitioning or JAX callbacks in the lowering
      rule i.e. function passed to rule parameter of mlir.register_lowering then add your
      primitive to jax._src.dispatch.prim_requires_devices_during_lowering set.
      This is needed because custom_partitioning and JAX callbacks need physical
      devices to create Shardings during lowering.
      This is a temporary state until we can create Shardings without physical
      devices.
    • {func}jax.numpy.argsort and {func}jax.numpy.sort now support the stable
      and descending arguments.
    • Several changes to the handling of shape polymorphism (used in
      {mod}jax.experimental.jax2tf and {mod}jax.experimental.export):
      • cleaner pretty-printing of symbolic expressions ({jax-issue}#19227)
      • added the ability to specify symbolic constraints on the dimension variables.
        This makes shape polymorphism more expressive, and gives a way to workaround
        limitations in the reasoning about inequalities.
        See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
      • with the addition of symbolic constraints ({jax-issue}#19235) we now
        consider dimension variables from different scopes to be different, even
        if they have the same name. Symbolic expressions from different scopes
        cannot interact, e.g., in arithmetic operations.
        Scopes are introduced by {func}jax.experimental.jax2tf.convert,
        {func}jax.experimental.export.symbolic_shape, {func}jax.experimental.export.symbolic_args_specs.
        The scope of a symbolic expression e can be read with e.scope and passed
        into the above functions to direct them to construct symbolic expressions in
        a given scope.
        See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
      • simplified and faster equality comparisons, where we consider two symbolic dimensions
        to be equal if the normalized form of their difference reduces to 0
        ({jax-issue}#19231; note that this may result in user-visible behavior
        changes)
      • improved the error messages for inconclusive inequality comparisons
        ({jax-issue}#19235).
      • the core.non_negative_dim API (introduced recently)
        was deprecated and core.max_dim and core.min_dim were introduced
        ({jax-issue}#18953) to express max and min for symbolic dimensions.
        You can use core.max_dim(d, 0) instead of core.non_negative_dim(d).
      • the shape_poly.is_poly_dim is deprecated in favor of export.is_symbolic_dim
        ({jax-issue}#19282).
      • the export.args_specs is deprecated in favor of export.symbolic_args_specs ({jax-issue}#​19283`).
      • the shape_poly.PolyShape and jax2tf.PolyShape are deprecated, use
        strings for polymorphic shapes specifications ({jax-issue}#19284).
      • JAX default native serialization version is now 9. This is relevant
        for {mod}jax.experimental.jax2tf and {mod}jax.experimental.export.
        See description of version numbers.
    • Refactored the API for jax.experimental.export. Instead of
      from jax.experimental.export import export you should use now
      from jax.experimental import export. The old way of importing will
      continue to work for a deprecation period of 3 months.
    • Added {func}jax.scipy.stats.sem.
    • {func}jax.numpy.unique with return_inverse = True returns inverse indices
      reshaped to the dimension of the input, following a similar change to
      {func}numpy.unique in NumPy 2.0.
    • {func}jax.numpy.sign now returns x / abs(x) for nonzero complex inputs. This is
      consistent with the behavior of {func}numpy.sign in NumPy version 2.0.
    • {func}jax.scipy.special.logsumexp with return_sign=True now uses the NumPy 2.0
      convention for the complex sign, x / abs(x). This is consistent with the behavior
      of {func}scipy.special.logsumexp in SciPy v1.13.
    • JAX now supports the bool DLPack type for both import and export.
      Previously bool values could not be imported and were exported as integers.
  • Deprecations & Removals

    • A number of previously deprecated functions have been removed, following a
      standard 3+ month deprecation cycle (see {ref}api-compatibility).
      This includes:
      • From {mod}jax.core: TracerArrayConversionError,
        TracerIntegerConversionError, UnexpectedTracerError,
        as_hashable_function, collections, dtypes, lu, map,
        namedtuple, partial, pp, ref, safe_zip, safe_map,
        source_info_util, total_ordering, traceback_util, tuple_delete,
        tuple_insert, and zip.
      • From {mod}jax.lax: dtypes, itertools, naryop, naryop_dtype_rule,
        standard_abstract_eval, standard_naryop, standard_primitive,
        standard_unop, unop, and unop_dtype_rule.
      • The jax.linear_util submodule and all its contents.
      • The jax.prng submodule and all its contents.
      • From {mod}jax.random: PRNGKeyArray, KeyArray, default_prng_impl,
        threefry_2x32, threefry2x32_key, threefry2x32_p, rbg_key, and
        unsafe_rbg_key.
      • From {mod}jax.tree_util: register_keypaths, AttributeKeyPathEntry, and
        GetItemKeyPathEntry.
      • from {mod}jax.interpreters.xla: backend_specific_translations, translations,
        register_translation, xla_destructure, TranslationRule, TranslationContext,
        axis_groups, ShapedArray, ConcreteArray, AxisEnv, backend_compile,
        and XLAOp.
      • from {mod}jax.numpy: NINF, NZERO, PZERO, row_stack, issubsctype,
        trapz, and in1d.
      • from {mod}jax.scipy.linalg: tril and triu.
    • The previously-deprecated method PRNGKeyArray.unsafe_raw_array has been
      removed. Use {func}jax.random.key_data instead.
    • bool(empty_array) now raises an error rather than returning False. This
      previously raised a deprecation warning, and follows a similar change in NumPy.
    • Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses
      the mhlo dialect, in favor of stablehlo. APIs that refer to "mhlo" will be
      removed in the future. Use the "stablehlo" dialect instead.
    • {mod}jax.random: passing batched keys directly to random number generation functions,
      such as {func}~jax.random.bits, {func}~jax.random.gamma, and others, is deprecated
      and will emit a FutureWarning. Use jax.vmap for explicit batching.
    • {func}jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0.

Configuration

📅 Schedule: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).

🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.

Rebasing: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.

🔕 Ignore: Close this PR and you won't be reminded about this update again.


  • If you want to rebase/retry this PR, check this box

This PR was generated by Mend Renovate. View the repository job log.

@renovate renovate bot requested a review from a team as a code owner July 14, 2024 02:14
@renovate renovate bot added the dependencies Pull requests that update a dependency file label Jul 14, 2024
@renovate renovate bot force-pushed the renovate/jax-0.x-lockfile branch from 4e9429b to f525d24 Compare July 30, 2024 00:48
@renovate renovate bot changed the title chore(deps): update dependency jax to v0.4.30 chore(deps): update dependency jax to v0.4.31 Jul 30, 2024
@renovate renovate bot force-pushed the renovate/jax-0.x-lockfile branch from f525d24 to 9006238 Compare August 16, 2024 13:52
@renovate renovate bot force-pushed the renovate/jax-0.x-lockfile branch from 9006238 to aeb84db Compare September 11, 2024 23:10
@renovate renovate bot changed the title chore(deps): update dependency jax to v0.4.31 chore(deps): update dependency jax to v0.4.32 Sep 11, 2024
@renovate renovate bot force-pushed the renovate/jax-0.x-lockfile branch from aeb84db to b67436d Compare September 12, 2024 22:41
@renovate renovate bot changed the title chore(deps): update dependency jax to v0.4.32 chore(deps): update dependency jax to v0.4.31 Sep 12, 2024
@renovate renovate bot force-pushed the renovate/jax-0.x-lockfile branch from b67436d to d9b96e5 Compare September 16, 2024 22:16
@renovate renovate bot changed the title chore(deps): update dependency jax to v0.4.31 chore(deps): update dependency jax to v0.4.33 Sep 16, 2024
@renovate renovate bot force-pushed the renovate/jax-0.x-lockfile branch from d9b96e5 to bc8ac6f Compare October 4, 2024 15:21
@renovate renovate bot changed the title chore(deps): update dependency jax to v0.4.33 chore(deps): update dependency jax to v0.4.34 Oct 4, 2024
@renovate renovate bot force-pushed the renovate/jax-0.x-lockfile branch from bc8ac6f to 7786f77 Compare October 22, 2024 23:15
@renovate renovate bot changed the title chore(deps): update dependency jax to v0.4.34 chore(deps): update dependency jax to v0.4.35 Oct 22, 2024
@renovate renovate bot force-pushed the renovate/jax-0.x-lockfile branch from 7786f77 to b6cfc8a Compare December 6, 2024 01:47
@renovate renovate bot changed the title chore(deps): update dependency jax to v0.4.35 chore(deps): update dependency jax to v0.4.36 Dec 6, 2024
@renovate renovate bot force-pushed the renovate/jax-0.x-lockfile branch from b6cfc8a to 0a2d1af Compare December 10, 2024 05:08
@renovate renovate bot changed the title chore(deps): update dependency jax to v0.4.36 chore(deps): update dependency jax to v0.4.37 Dec 10, 2024
@renovate renovate bot force-pushed the renovate/jax-0.x-lockfile branch from 0a2d1af to de34aef Compare December 18, 2024 00:25
@renovate renovate bot changed the title chore(deps): update dependency jax to v0.4.37 chore(deps): update dependency jax to v0.4.38 Dec 18, 2024
@renovate renovate bot force-pushed the renovate/jax-0.x-lockfile branch from de34aef to 473b292 Compare January 17, 2025 20:52
@renovate renovate bot changed the title chore(deps): update dependency jax to v0.4.38 chore(deps): update dependency jax to v0.5.0 Jan 17, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
dependencies Pull requests that update a dependency file
Projects
None yet
Development

Successfully merging this pull request may close these issues.

0 participants