Releases: jax-ml/jax

JAX v0.8.0

15 Oct 23:38

  • Breaking changes:

    • JAX is changing the default jax.pmap implementation to one implemented in
      terms of jax.jit and jax.shard_map. jax.pmap is in maintenance mode
      and we encourage all new code to use jax.shard_map directly. See the
      migration guide for
      more information.
    • The auto= parameter of jax.experimental.shard_map.shard_map has been
      removed. This means that jax.experimental.shard_map.shard_map no longer
      supports nesting. If you want to nest shard_map calls, please use
      jax.shard_map.
    • JAX no longer allows passing objects that support __jax_array__ directly
      to APIs such as jit-ed functions. Call jax.numpy.asarray on them first.
    • jax.numpy.cov now returns NaN for empty arrays (#32305),
      and matches NumPy 2.2 behavior for single-row design matrices (#32308).
    • JAX no longer accepts Array values where a dtype value is expected. Call
      .dtype on these values first.
    • The deprecated function jax.interpreters.mlir.custom_call was
      removed.
    • The jax.util, jax.extend.ffi, and jax.experimental.host_callback
      modules have been removed. All public APIs within these modules were
      deprecated and removed in v0.7.0 or earlier.
    • The deprecated symbol jax.custom_derivatives.custom_jvp_call_jaxpr_p
      was removed.
    • jax.experimental.multihost_utils.process_allgather raises an error when
      the input is a jax.Array that is not fully addressable and tiled=False.
      To fix this, pass tiled=True to your process_allgather invocation.
    • From jax.experimental.compilation_cache, the deprecated symbols
      is_initialized and initialize_cache were removed.
    • The deprecated function jax.interpreters.xla.canonicalize_dtype
      was removed.
    • jaxlib.hlo_helpers has been removed. Use jax.ffi instead.
    • The option jax_cpu_enable_gloo_collectives has been removed. Use
      jax_cpu_collectives_implementation instead.
    • The previously-deprecated interpolation argument to
      jax.numpy.percentile and jax.numpy.quantile has been
      removed; use method instead.
    • The JAX-internal for_loop primitive was removed. Its functionality,
      reading from and writing to refs in the loop body, is now directly
      supported by jax.lax.fori_loop. If you need help updating your
      code, please file a bug.
    • jax.numpy.trim_zeros now errors for non-1D input.
    • The where argument to jax.numpy.sum and other reductions is now
      required to be boolean. Non-boolean values have resulted in a
      DeprecationWarning since JAX v0.5.0.
    • The deprecated functions in jax.dlpack, jax.errors,
      jax.lib.xla_bridge, jax.lib.xla_client, and
      jax.lib.xla_extension were removed.
    • jax.interpreters.mlir.dense_bool_array was removed. Use MLIR APIs to
      construct attributes instead.
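A minimal sketch of how code might adapt to several of the breaking changes above, assuming JAX v0.8.0 running in a single process. The Wrapper class is a hypothetical stand-in for any user type that defines __jax_array__; the remaining calls use the replacements the notes name. With one process and a fully addressable input, process_allgather with tiled=True simply returns the local data.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import multihost_utils


class Wrapper:
    """Hypothetical user type that exposes its data through __jax_array__."""

    def __init__(self, values):
        self.values = np.asarray(values, dtype=np.float32)

    def __jax_array__(self):
        return jnp.asarray(self.values)


@jax.jit
def double(x):
    return 2 * x


# Convert explicitly rather than passing the Wrapper straight to a jitted function.
doubled = double(jnp.asarray(Wrapper([1.0, 2.0, 3.0])))

# Pass `.dtype` where a dtype is expected, not the array itself.
template = jnp.arange(3, dtype=jnp.int32)
zeros = jnp.zeros(4, dtype=template.dtype)

# `method=` replaces the removed `interpolation=` argument.
median_low = jnp.percentile(jnp.arange(4.0), 50, method="lower")

# Masks passed through `where=` must be boolean.
values = jnp.array([1, -2, 3])
masked_total = jnp.sum(values, where=jnp.array([1, 0, 1]).astype(bool))

# tiled=True concatenates each process's data along axis 0; with a single
# process the result is just the local data, returned as a NumPy array.
gathered = multihost_utils.process_allgather(jnp.arange(4.0), tiled=True)
```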
  • Changes

    • jax.numpy.linalg.eig now returns a namedtuple (with attributes
      eigenvalues and eigenvectors) instead of a plain tuple.
    • jax.grad and jax.vjp will now always round primals to
      float32 if float64 mode is not enabled.
    • jax.dlpack.from_dlpack now accepts arrays with non-default layouts,
      for example, transposed.
    • The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses
      cusolver. The magma and LAPACK implementations are still available via the
      new implementation argument to jax.lax.linalg.eig
      (#27265). The use_magma argument is now deprecated in favor
      of implementation.
    • jax.numpy.trim_zeros now follows NumPy 2.2 in supporting
      multi-dimensional inputs.
  • Deprecations

    • jax.experimental.enable_x64 and jax.experimental.disable_x64
      are deprecated in favor of the new non-experimental context manager
      jax.enable_x64.
    • jax.experimental.shard_map.shard_map is deprecated; going forward use
      jax.shard_map.
    • jax.experimental.pjit.pjit is deprecated; going forward use
      jax.jit.

JAX v0.7.2

16 Sep 17:19

  • Breaking changes:

    • jax.dlpack.from_dlpack no longer accepts a DLPack capsule. This
      behavior was deprecated and is now removed. The function must be called
      with an array implementing __dlpack__ and __dlpack_device__.
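As an illustration, any object implementing both __dlpack__ and __dlpack_device__ can be passed directly. This sketch assumes NumPy 1.22 or newer, whose arrays implement both methods, and a host (CPU) array.

```python
import jax
import jax.dlpack
import numpy as np

# NumPy arrays implement __dlpack__ and __dlpack_device__, so they can be
# handed to from_dlpack directly, without creating a DLPack capsule first.
host = np.arange(4, dtype=np.float32)
device_array = jax.dlpack.from_dlpack(host)
```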
  • Changes

    • The minimum supported NumPy version is now 2.0. Since SciPy 1.13 is required
      for NumPy 2.0 support, the minimum supported SciPy version is now 1.13.

    • JAX now represents constants in its internal jaxpr representation as a
      LiteralArray, which is a private JAX type that duck types as a
      numpy.ndarray. This type may be exposed to users via custom_jvp rules,
      for example, and may break code that uses isinstance(x, np.ndarray). If
      this breaks your code, you may convert these arrays to classic NumPy arrays
      using np.asarray(x).

  • Bug fixes

    • arr.view(dtype=None) now returns the array unchanged, matching NumPy's
      semantics. Previously it returned the array with a float dtype.
    • jax.random.randint now produces a less-biased distribution for 8-bit and
      16-bit integer types (#27742). To restore the previous biased
      behavior, you may temporarily set the jax_safer_randint configuration to
      False, but note this is a temporary config that will be removed in a
      future release.
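A quick way to exercise the affected path, using only the public jax.random API: draw int8 samples and confirm they stay within [minval, maxval). The seed and bounds are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
# 8-bit integer draws in [0, 100): one of the cases whose bias v0.7.2 reduced.
samples = jax.random.randint(key, shape=(1000,), minval=0, maxval=100, dtype=jnp.int8)
```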
  • Deprecations:

    • The parameters enable_xla and native_serialization for jax2tf.convert
      are deprecated and will be removed in a future version of JAX. These were
      used for jax2tf with non-native serialization, which has now been removed.
    • Setting the config state jax_pmap_no_rank_reduction to False is
      deprecated. By default, jax_pmap_no_rank_reduction will be set to True
      and jax.pmap shards will not have their rank reduced, keeping the same
      rank as their enclosing array.

JAX v0.7.1

20 Aug 16:04

  • New features

    • JAX now ships Python 3.14 and 3.14t wheels.
    • JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only
      offered free-threading builds on Linux.
  • Changes

    • Exposed jax.set_mesh, which acts as both a global setter and a context manager.
      Removed jax.sharding.use_mesh in favor of jax.set_mesh.
    • JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer remain
      supported.
    • jax.lax.dot now implements the general dot product via the optional
      dimension_numbers argument.
  • Deprecations:

    • jax.lax.zeros_like_array is deprecated. Please use
      jax.numpy.zeros_like instead.
    • Attempting to import jax.experimental.host_callback now results in
      a DeprecationWarning, and will result in an ImportError starting in JAX
      v0.8.0. Its APIs have raised NotImplementedError since JAX version 0.4.35.
    • In jax.lax.dot, passing the precision and preferred_element_type
      arguments by position is deprecated. Pass them by explicit keyword instead.
    • Several dozen internal APIs have been deprecated from jax.interpreters.ad,
      jax.interpreters.batching, and jax.interpreters.partial_eval; they
      are used rarely if ever outside JAX itself, and most are deprecated without any
      public replacement.

JAX v0.7.0

22 Jul 20:33

  • New features:

    • Added jax.P which is an alias for jax.sharding.PartitionSpec.
    • Added jax.tree.reduce_associative.
  • Breaking changes:

    • JAX is migrating from GSPMD to Shardy by default. See the
      migration guide
      for more information.
    • JAX autodiff is switching to using direct linearization by default (instead of
      implementing linearization via JVP and partial eval).
      See migration guide
      for more information.
    • jax.stages.OutInfo has been replaced with jax.ShapeDtypeStruct.
    • jax.jit now requires fun to be passed by position, and additional
      arguments to be passed by keyword. Doing otherwise will result in an error
      starting in v0.7.x. This raised a DeprecationWarning in v0.6.x.
    • The minimum Python version is now 3.11. 3.11 will remain the minimum
      supported version until July 2026.
    • Layout API renames:
      • Layout, .layout, .input_layouts and .output_layouts have been
        renamed to Format, .format, .input_formats and .output_formats
      • DeviceLocalLayout, .device_local_layout have been renamed to Layout
        and .layout
    • jax.experimental.shard module has been deleted and all the APIs have been
      moved to the jax.sharding endpoint. So use jax.sharding.reshard,
      jax.sharding.auto_axes and jax.sharding.explicit_axes instead of their
      experimental endpoints.
    • lax.infeed and lax.outfeed were removed, after being deprecated in
      JAX 0.6. The transfer_to_infeed and transfer_from_outfeed methods were
      also removed from the Device objects.
    • The jax.extend.core.primitives.pjit_p primitive has been renamed to
      jit_p, and its name attribute has changed from "pjit" to "jit".
      This affects the string representations of jaxprs. The same primitive is no
      longer exported from the jax.experimental.pjit module.
    • The (undocumented) function jax.extend.backend.add_clear_backends_callback
      has been removed. Users should use jax.extend.backend.register_backend_cache
      instead.
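A minimal sketch of the calling convention jax.jit now enforces: the function is passed by position and every option by keyword. The functions raise_to and scaled_sum are hypothetical examples.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Decorator form: partial supplies the keyword options, and the decorated
# function is then passed to jax.jit positionally.
@partial(jax.jit, static_argnames=("power",))
def raise_to(x, power):
    return x ** power


def scaled_sum(x, scale):
    return scale * x.sum()


# Call form: the function by position, every option by keyword.
# jax.jit(fun=scaled_sum), or options passed positionally, are rejected.
fast_scaled_sum = jax.jit(scaled_sum, static_argnames="scale")
```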
  • Deprecations:

    • jax.dlpack.SUPPORTED_DTYPES is deprecated; please use the new
      jax.dlpack.is_supported_dtype function.
    • jax.scipy.special.sph_harm has been deprecated following a similar
      deprecation in SciPy; use jax.scipy.special.sph_harm_y instead.
    • From jax.interpreters.xla, the previously deprecated symbols
      abstractify and pytype_aval_mappings have been removed.
    • jax.interpreters.xla.canonicalize_dtype is deprecated. For
      canonicalizing dtypes, prefer jax.dtypes.canonicalize_dtype.
      For checking whether an object is a valid jax input, prefer
      jax.core.valid_jaxtype.
    • From jax.core, the previously deprecated symbols AxisName,
      ConcretizationTypeError, axis_frame, call_p, closed_call_p,
      get_type, trace_state_clean, typematch, and typecheck have been
      removed.
    • From jax.lib.xla_client, the previously deprecated symbols
      DeviceAssignment, get_topology_for_devices, and mlir_api_version
      have been removed.
    • jax.extend.ffi was removed after being deprecated in v0.5.0.
      Use jax.ffi instead.
    • jax.lib.xla_bridge.get_compile_options is deprecated, and replaced by
      jax.extend.backend.get_compile_options.

JAX v0.6.2

17 Jun 23:06

  • New features:

    • Added jax.tree.broadcast which implements a pytree prefix broadcasting helper.
  • Changes

    • The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12.

JAX v0.6.1

21 May 18:30

  • New features:

    • Added jax.lax.axis_size which returns the size of the mapped axis
      given its name.
  • Changes

    • Additional checking for the versions of CUDA package dependencies was
      reenabled, having been accidentally disabled in a previous release.
    • JAX nightly packages are now published to artifact registry. To install
      these packages, see the JAX installation guide.
    • jax.sharding.PartitionSpec no longer inherits from a tuple.
    • jax.ShapeDtypeStruct is now immutable. Use the .update method to
      create a modified ShapeDtypeStruct instead of updating one in place.
  • Deprecations

    • jax.custom_derivatives.custom_jvp_call_jaxpr_p is deprecated, and will be
      removed in JAX v0.7.0.

JAX v0.6.0

17 Apr 00:04

  • Breaking changes

    • jax.numpy.array no longer accepts None. This behavior was
      deprecated since November 2023 and is now removed.
    • Removed the config.jax_data_dependent_tracing_fallback config option,
      which was added temporarily in v0.4.36 to allow users to opt out of the
      new "stackless" tracing machinery.
    • Removed the config.jax_eager_pmap config option.
    • Disallow the calling of lower and trace AOT APIs on the result
      of jax.jit if there have been subsequent wrappers applied.
      Previously this worked, but silently ignored the wrappers.
      The workaround is to apply jax.jit last among the wrappers,
      and similarly for jax.pmap.
      See #27873.
    • The cuda12_pip extra for jax has been removed; use pip install jax[cuda12]
      instead.
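A sketch of the supported ordering, using a hypothetical per_row function: jax.jit is applied last, so lower and compile operate on the fully wrapped computation.

```python
import jax
import jax.numpy as jnp


def per_row(row):
    return jnp.tanh(row).sum()


# jax.jit is the outermost wrapper, so lower/compile see the vmap as well.
batched = jax.jit(jax.vmap(per_row))

x = jnp.ones((3, 2))
lowered = batched.lower(x)      # trace and lower for this input shape
compiled = lowered.compile()    # compile ahead of time
result = compiled(x)
```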
  • Changes

    • The minimum cuDNN version is v9.8.
    • JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain
      supported.
    • JAX package extras are now updated to use dash instead of underscore to
      align with PEP 685. For instance, if you were previously using pip install jax[cuda12_local]
      to install JAX, run pip install jax[cuda12-local] instead.
    • jax.jit now requires fun to be passed by position, and additional
      arguments to be passed by keyword. Doing otherwise will result in a
      DeprecationWarning in v0.6.X, and an error starting in v0.7.X.
  • Deprecations

    • jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten
      instead.
    • Implemented host callback handlers for CPU and GPU devices using XLA's FFI
      and removed existing CPU/GPU handlers using XLA's custom call.
    • All APIs in jax.lib.xla_extension are now deprecated.
    • jax.interpreters.mlir.hlo and jax.interpreters.mlir.func_dialect,
      which were accidental exports, have been removed. If needed, they are
      available from jax.extend.mlir.
    • jax.interpreters.mlir.custom_call is deprecated. The APIs provided by
      jax.ffi should be used instead.
    • The deprecated use of jax.ffi.ffi_call with inline arguments is no
      longer supported. jax.ffi.ffi_call now unconditionally returns a
      callable.
    • The following exports in jax.lib.xla_client are deprecated:
      get_topology_for_devices, heap_profile, mlir_api_version, Client,
      CompileOptions, DeviceAssignment, Frame, HloSharding, OpSharding,
      Traceback.
    • The following internal APIs in jax.util are deprecated:
      HashableFunction, as_hashable_function, cache, safe_map, safe_zip,
      split_dict, split_list, split_list_checked, split_merge, subvals,
      toposort, unzip2, wrap_name, and wraps.
    • jax.dlpack.to_dlpack has been deprecated. You can usually pass a JAX
      Array directly to the from_dlpack function of another framework. If you
      need the functionality of to_dlpack, use the __dlpack__ attribute of an
      array.
    • jax.lax.infeed, jax.lax.infeed_p, jax.lax.outfeed, and
      jax.lax.outfeed_p are deprecated and will be removed in JAX v0.7.0.
    • Several previously-deprecated APIs have been removed, including:
      • From jax.lib.xla_client: ArrayImpl, FftType, PaddingType,
        PrimitiveType, XlaBuilder, dtype_to_etype,
        ops, register_custom_call_target, shape_from_pyval, Shape,
        XlaComputation.
      • From jax.lib.xla_extension: ArrayImpl, XlaRuntimeError.
      • From jax: jax.treedef_is_leaf, jax.tree_flatten, jax.tree_map,
        jax.tree_leaves, jax.tree_structure, jax.tree_transpose, and
        jax.tree_unflatten. Replacements can be found in jax.tree or
        jax.tree_util.
      • From jax.core: AxisSize, ClosedJaxpr, EvalTrace, InDBIdx, InputType,
        Jaxpr, JaxprEqn, Literal, MapPrimitive, OpaqueTraceState, OutDBIdx,
        Primitive, Token, TRACER_LEAK_DEBUGGER_WARNING, Var, concrete_aval,
        dedup_referents, escaped_tracer_error, extend_axis_env_nd, full_lower,
        get_referent, jaxpr_as_fun, join_effects, lattice_join,
        leaked_tracer_error, maybe_find_leaked_tracers, raise_to_shaped,
        raise_to_shaped_mappings, reset_trace_state, str_eqn_compact,
        substitute_vars_in_output_ty, typecompat, and used_axis_names_jaxpr. Most
        have no public replacement, though a few are available at jax.extend.core.
      • The vectorized argument to jax.pure_callback and
        jax.ffi.ffi_call. Use the vmap_method parameter instead.
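The removed top-level tree aliases map onto the jax.tree module. This sketch assumes a simple dict of arrays as the pytree and shows the replacements in use.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# Removed top-level alias -> jax.tree replacement:
#   jax.tree_map       -> jax.tree.map
#   jax.tree_leaves    -> jax.tree.leaves
#   jax.tree_flatten   -> jax.tree.flatten
#   jax.tree_unflatten -> jax.tree.unflatten (also the suggested replacement
#                         for the deprecated jax.tree_util.build_tree)
doubled = jax.tree.map(lambda p: 2 * p, params)
num_leaves = len(jax.tree.leaves(params))
leaves, treedef = jax.tree.flatten(params)
rebuilt = jax.tree.unflatten(treedef, leaves)
```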

JAX v0.5.3

19 Mar 18:20

  • New Features

    • Added an allow_negative_indices option to jax.lax.dynamic_slice,
      jax.lax.dynamic_update_slice and related functions. The default is
      True, matching the current behavior. If set to False, JAX does not need to
      emit code clamping negative indices, which reduces code size.
    • Added a replace option to jax.random.categorical to enable sampling
      without replacement.

JAX v0.5.2

05 Mar 02:36

Patch release of 0.5.1

  • Bug fixes

    • Fixed TPU metric logging and tpu-info, both of which were broken in 0.5.1.

JAX v0.5.1

24 Feb 21:03

  • New Features

    • Added an experimental jax.experimental.custom_dce.custom_dce
      decorator to support customizing the behavior of opaque functions under
      JAX-level dead code elimination (DCE). See #25956 for more
      details.
    • Added low-level reduction APIs in jax.lax: jax.lax.reduce_sum,
      jax.lax.reduce_prod, jax.lax.reduce_max, jax.lax.reduce_min,
      jax.lax.reduce_and, jax.lax.reduce_or, and jax.lax.reduce_xor.
    • jax.lax.linalg.qr, and jax.scipy.linalg.qr, now support
      column-pivoting on CPU and GPU. See #20282 and
      #25955 for more details.
  • Changes

    • JAX_CPU_COLLECTIVES_IMPLEMENTATION and JAX_NUM_CPU_DEVICES now work as
      env vars. Previously they could only be specified via jax.config or flags.
    • JAX_CPU_COLLECTIVES_IMPLEMENTATION now defaults to 'gloo', meaning
      multi-process CPU communication works out-of-the-box.
    • The jax[tpu] TPU extra no longer depends on the libtpu-nightly package.
      This package may safely be removed if it is present on your machine; JAX now
      uses libtpu instead.
  • Deprecations

    • The internal function linear_util.wrap_init and the constructor
      core.Jaxpr now must take a non-empty core.DebugInfo kwarg. For
      a limited time, a DeprecationWarning is printed if
      jax.extend.linear_util.wrap_init is used without debugging info.
      As a downstream effect, several other internal functions also need debug
      info. This change does not affect public APIs.
      See #26480 for more detail.
  • Bug fixes

    • TPU runtime startup and shutdown time should be significantly improved on
      TPU v5e and newer (from around 17s to around 8s). If not already set, you may
      need to enable transparent hugepages in your VM image
      (sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled').
      We hope to improve this further in future releases.
    • The persistent compilation cache no longer writes an access-time file if
      JAX_COMPILATION_CACHE_MAX_SIZE is unset or set to -1, i.e. if the LRU
      eviction policy isn't enabled. This should improve performance when using
      the cache with large-scale network storage.
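A minimal sketch of enabling the persistent cache, assuming a writable temporary directory. jax_compilation_cache_max_size is left at -1, the setting under which the access-time file is no longer written; the jitted function is an arbitrary example.

```python
import tempfile

import jax
import jax.numpy as jnp

cache_dir = tempfile.mkdtemp(prefix="jax_cache_")
jax.config.update("jax_compilation_cache_dir", cache_dir)
# Cache every compilation, even fast ones, so this small function is cached too.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# -1 means no size limit and no LRU eviction, the case in which no
# access-time files are written.
jax.config.update("jax_compilation_cache_max_size", -1)


@jax.jit
def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


y = f(jnp.arange(4.0))
```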