Releases: jax-ml/jax
JAX v0.8.0
- Breaking changes:
  - JAX is changing the default `jax.pmap` implementation to one implemented in
    terms of `jax.jit` and `jax.shard_map`. `jax.pmap` is in maintenance mode
    and we encourage all new code to use `jax.shard_map` directly. See the
    migration guide for more information.
  - The `auto=` parameter of `jax.experimental.shard_map.shard_map` has been
    removed. This means that `jax.experimental.shard_map.shard_map` no longer
    supports nesting. If you want to nest `shard_map` calls, please use
    `jax.shard_map`.
  - JAX no longer allows passing objects that support `__jax_array__` directly
    to, e.g., `jit`-ed functions. Call `jax.numpy.asarray` on them first.
  - `jax.numpy.cov` now returns NaN for empty arrays (#32305), and matches
    NumPy 2.2 behavior for single-row design matrices (#32308).
  - JAX no longer accepts `Array` values where a `dtype` value is expected.
    Call `.dtype` on these values first.
  - The deprecated function `jax.interpreters.mlir.custom_call` was removed.
  - The `jax.util`, `jax.extend.ffi`, and `jax.experimental.host_callback`
    modules have been removed. All public APIs within these modules were
    deprecated and removed in v0.7.0 or earlier.
  - The deprecated symbol `jax.custom_derivatives.custom_jvp_call_jaxpr_p` was
    removed.
  - `jax.experimental.multihost_utils.process_allgather` raises an error when
    the input is a `jax.Array` that is not fully addressable and
    `tiled=False`. To fix this, pass `tiled=True` to your `process_allgather`
    invocation.
  - From `jax.experimental.compilation_cache`, the deprecated symbols
    `is_initialized` and `initialize_cache` were removed.
  - The deprecated function `jax.interpreters.xla.canonicalize_dtype` was
    removed.
  - `jaxlib.hlo_helpers` has been removed. Use `jax.ffi` instead.
  - The option `jax_cpu_enable_gloo_collectives` has been removed. Use
    `jax_cpu_collectives_implementation` instead.
  - The previously-deprecated `interpolation` argument to
    `jax.numpy.percentile` and `jax.numpy.quantile` has been removed; use
    `method` instead.
  - The JAX-internal `for_loop` primitive was removed. Its functionality,
    reading from and writing to refs in the loop body, is now directly
    supported by `jax.lax.fori_loop`. If you need help updating your code,
    please file a bug.
  - `jax.numpy.trimzeros` now errors for non-1D input.
  - The `where` argument to `jax.numpy.sum` and other reductions is now
    required to be boolean. Non-boolean values have resulted in a
    `DeprecationWarning` since JAX v0.5.0.
  - The deprecated functions in `jax.dlpack`, `jax.errors`,
    `jax.lib.xla_bridge`, `jax.lib.xla_client`, and `jax.lib.xla_extension`
    were removed.
  - `jax.interpreters.mlir.dense_bool_array` was removed. Use MLIR APIs to
    construct attributes instead.
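A minimal migration sketch for three of the breaking changes above: the `__jax_array__` change, the boolean `where=` requirement, and the removal of `interpolation=`. It assumes a recent JAX install. The `Wrapper` class is a hypothetical stand-in for any user type that implements `__jax_array__`, and `double` is an illustrative jit-ed function.

```python
import jax
import jax.numpy as jnp


class Wrapper:
    """Hypothetical user type that knows how to convert itself to a JAX array."""

    def __init__(self, value):
        self.value = value

    def __jax_array__(self):
        return jnp.asarray(self.value)


double = jax.jit(lambda a: a * 2)

# Convert explicitly with jnp.asarray before calling a jit-ed function.
doubled = double(jnp.asarray(Wrapper([1.0, 2.0, 3.0])))

x = jnp.arange(10.0)

# `where=` must be a boolean mask; this sums only the elements greater than 4.
masked_total = jnp.sum(x, where=x > 4)

# `interpolation=` is gone; the same choice is spelled `method=`.
p90 = jnp.percentile(x, 90, method="linear")

# Where a dtype is expected, pass `x.dtype` rather than the array `x` itself.
zeros = jnp.zeros(3, dtype=x.dtype)
```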
- Changes
  - `jax.numpy.linalg.eig` now returns a namedtuple (with attributes
    `eigenvalues` and `eigenvectors`) instead of a plain tuple.
  - `jax.grad` and `jax.vjp` will now always round primals to `float32` if
    `float64` mode is not enabled.
  - `jax.dlpack.from_dlpack` now accepts arrays with non-default layouts, for
    example, transposed.
  - The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses
    cuSOLVER. The MAGMA and LAPACK implementations are still available via the
    new `implementation` argument to `jax.lax.linalg.eig` (#27265). The
    `use_magma` argument is now deprecated in favor of `implementation`.
  - `jax.numpy.trim_zeros` now follows NumPy 2.2 in supporting
    multi-dimensional inputs.
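A small sketch of consuming `jax.numpy.linalg.eig` in a way that works both before and after the switch to a namedtuple. It assumes the CPU backend; the matrix `a` is an arbitrary illustrative example chosen to be upper-triangular so its eigenvalues (2 and 3) are easy to check.

```python
import jax.numpy as jnp

# Upper-triangular, so the eigenvalues are its diagonal entries: 2 and 3.
a = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])

result = jnp.linalg.eig(a)

# Tuple unpacking works with both the old plain tuple and the new namedtuple.
w, v = result

# On JAX v0.8.0 and later, the same arrays are also reachable by name:
#   result.eigenvalues, result.eigenvectors

# Each column v[:, i] satisfies a @ v[:, i] == w[i] * v[:, i].
residual = jnp.abs(a @ v - v * w).max()
```

Unpacking positionally keeps code portable across versions; attribute access is clearer once only v0.8.0+ needs to be supported.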
- Deprecations
  - `jax.experimental.enable_x64` and `jax.experimental.disable_x64` are
    deprecated in favor of the new non-experimental context manager
    `jax.enable_x64`.
  - `jax.experimental.shard_map.shard_map` is deprecated; going forward use
    `jax.shard_map`.
  - `jax.experimental.pjit.pjit` is deprecated; going forward use `jax.jit`.
JAX v0.7.2
- Breaking changes:
  - `jax.dlpack.from_dlpack` no longer accepts a DLPack capsule. This behavior
    was deprecated and is now removed. The function must be called with an
    array implementing `__dlpack__` and `__dlpack_device__`.
- Changes
  - The minimum supported NumPy version is now 2.0. Since SciPy 1.13 is
    required for NumPy 2.0 support, the minimum supported SciPy version is now
    1.13.
  - JAX now represents constants in its internal jaxpr representation as a
    `LiteralArray`, which is a private JAX type that duck types as a
    `numpy.ndarray`. This type may be exposed to users via `custom_jvp` rules,
    for example, and may break code that uses `isinstance(x, np.ndarray)`. If
    this breaks your code, you may convert these arrays to classic NumPy arrays
    using `np.asarray(x)`.
- Bug fixes
  - `arr.view(dtype=None)` now returns the array unchanged, matching NumPy's
    semantics. Previously it returned the array with a float dtype.
  - `jax.random.randint` now produces a less-biased distribution for 8-bit and
    16-bit integer types (#27742). To restore the previous biased behavior, you
    may temporarily set the `jax_safer_randint` configuration to `False`, but
    note this is a temporary config that will be removed in a future release.
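A sketch of drawing small-integer samples with `jax.random.randint`, the case affected by the bias fix above. The sample size and range are arbitrary; the `jax_safer_randint` flag is only mentioned in a comment and is not set here.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# Draw int8 values in [0, 100). Since v0.7.2 these small-integer draws are
# less biased; setting the temporary `jax_safer_randint` config to False
# would restore the old behavior (not done here).
samples = jax.random.randint(key, (10_000,), minval=0, maxval=100, dtype=jnp.int8)

# Count how often each value occurs, e.g. to eyeball uniformity.
counts = jnp.bincount(samples.astype(jnp.int32), length=100)
```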
- Deprecations:
  - The parameters `enable_xla` and `native_serialization` for
    `jax2tf.convert` are deprecated and will be removed in a future version of
    JAX. These were used for jax2tf with non-native serialization, which has
    now been removed.
  - Setting the config state `jax_pmap_no_rank_reduction` to `False` is
    deprecated. By default, `jax_pmap_no_rank_reduction` will be set to `True`
    and `jax.pmap` shards will not have their rank reduced, keeping the same
    rank as their enclosing array.
JAX v0.7.1
- New features
  - JAX now ships Python 3.14 and 3.14t wheels.
  - JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only
    offered free-threading builds on Linux.
- Changes
  - Exposed `jax.set_mesh`, which acts as a global setter and a context
    manager. Removed `jax.sharding.use_mesh` in favor of `jax.set_mesh`.
  - JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer remain
    supported.
  - `jax.lax.dot` now implements the general dot product via the optional
    `dimension_numbers` argument.
- Deprecations:
  - `jax.lax.zeros_like_array` is deprecated. Please use
    `jax.numpy.zeros_like` instead.
  - Attempting to import `jax.experimental.host_callback` now results in a
    `DeprecationWarning`, and will result in an `ImportError` starting in JAX
    v0.8.0. Its APIs have raised `NotImplementedError` since JAX version
    0.4.35.
  - In `jax.lax.dot`, passing the `precision` and `preferred_element_type`
    arguments by position is deprecated. Pass them by explicit keyword instead.
  - Several dozen internal APIs have been deprecated from
    `jax.interpreters.ad`, `jax.interpreters.batching`, and
    `jax.interpreters.partial_eval`; they are used rarely if ever outside JAX
    itself, and most are deprecated without any public replacement.
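A sketch of the keyword-only style for `jax.lax.dot` and the `jax.numpy.zeros_like` replacement described above. The bfloat16 inputs with a float32 accumulator are an illustrative choice, not something the release notes prescribe.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 3), dtype=jnp.bfloat16)
b = jnp.ones((3, 4), dtype=jnp.bfloat16)

# Pass precision and preferred_element_type by keyword, not by position.
out = lax.dot(a, b,
              precision=lax.Precision.HIGHEST,
              preferred_element_type=jnp.float32)

# Replacement for the deprecated jax.lax.zeros_like_array.
zeros = jnp.zeros_like(out)
```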
JAX v0.7.0
- New features:
  - Added `jax.P`, which is an alias for `jax.sharding.PartitionSpec`.
  - Added `jax.tree.reduce_associative`.
- Breaking changes:
  - JAX is migrating from GSPMD to Shardy by default. See the migration guide
    for more information.
  - JAX autodiff is switching to using direct linearization by default
    (instead of implementing linearization via JVP and partial eval). See the
    migration guide for more information.
  - `jax.stages.OutInfo` has been replaced with `jax.ShapeDtypeStruct`.
  - `jax.jit` now requires `fun` to be passed by position, and additional
    arguments to be passed by keyword. Doing otherwise will result in an error
    starting in v0.7.x. This raised a `DeprecationWarning` in v0.6.x.
  - The minimum Python version is now 3.11. 3.11 will remain the minimum
    supported version until July 2026.
  - Layout API renames:
    - `Layout`, `.layout`, `.input_layouts` and `.output_layouts` have been
      renamed to `Format`, `.format`, `.input_formats` and `.output_formats`.
    - `DeviceLocalLayout` and `.device_local_layout` have been renamed to
      `Layout` and `.layout`.
  - The `jax.experimental.shard` module has been deleted and all the APIs have
    been moved to the `jax.sharding` endpoint. So use `jax.sharding.reshard`,
    `jax.sharding.auto_axes` and `jax.sharding.explicit_axes` instead of their
    experimental endpoints.
  - `lax.infeed` and `lax.outfeed` were removed, after being deprecated in
    JAX 0.6. The `transfer_to_infeed` and `transfer_from_outfeed` methods were
    also removed from the `Device` objects.
  - The `jax.extend.core.primitives.pjit_p` primitive has been renamed to
    `jit_p`, and its `name` attribute has changed from `"pjit"` to `"jit"`.
    This affects the string representations of jaxprs. The same primitive is no
    longer exported from the `jax.experimental.pjit` module.
  - The (undocumented) function
    `jax.extend.backend.add_clear_backends_callback` has been removed. Users
    should use `jax.extend.backend.register_backend_cache` instead.
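A sketch of the calling convention `jax.jit` now enforces: the function positionally, every other option by keyword. `scale` and `power` are hypothetical example functions; `functools.partial` is one common way to keep the decorator form keyword-only.

```python
import functools

import jax
import jax.numpy as jnp


def scale(x, factor):
    return x * factor


# `fun` is passed positionally; every other jit option is a keyword.
scale_jit = jax.jit(scale, static_argnames="factor")


# In decorator form, bind the keyword options with functools.partial.
@functools.partial(jax.jit, static_argnames=("n",))
def power(x, n):
    return x ** n


scaled = scale_jit(jnp.arange(3.0), factor=2.0)
cubed = power(jnp.array(2.0), n=3)
```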
- Deprecations:
  - `jax.dlpack.SUPPORTED_DTYPES` is deprecated; please use the new
    `jax.dlpack.is_supported_dtype` function.
  - `jax.scipy.special.sph_harm` has been deprecated following a similar
    deprecation in SciPy; use `jax.scipy.special.sph_harm_y` instead.
  - From `jax.interpreters.xla`, the previously deprecated symbols
    `abstractify` and `pytype_aval_mappings` have been removed.
  - `jax.interpreters.xla.canonicalize_dtype` is deprecated. For
    canonicalizing dtypes, prefer `jax.dtypes.canonicalize_dtype`. For
    checking whether an object is a valid jax input, prefer
    `jax.core.valid_jaxtype`.
  - From `jax.core`, the previously deprecated symbols `AxisName`,
    `ConcretizationTypeError`, `axis_frame`, `call_p`, `closed_call_p`,
    `get_type`, `trace_state_clean`, `typematch`, and `typecheck` have been
    removed.
  - From `jax.lib.xla_client`, the previously deprecated symbols
    `DeviceAssignment`, `get_topology_for_devices`, and `mlir_api_version`
    have been removed.
  - `jax.extend.ffi` was removed after being deprecated in v0.5.0. Use
    `jax.ffi` instead.
  - `jax.lib.xla_bridge.get_compile_options` is deprecated, and replaced by
    `jax.extend.backend.get_compile_options`.
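A sketch of the two replacements suggested for `jax.interpreters.xla.canonicalize_dtype`. It assumes 64-bit mode is left at its default (off), in which case `float64` canonicalizes to `float32`; the test compares against JAX's default float dtype so it also holds if x64 is on.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Replacement for the deprecated jax.interpreters.xla.canonicalize_dtype:
# with 64-bit mode off (the default), float64 canonicalizes to float32.
canonical = jax.dtypes.canonicalize_dtype(np.float64)

# Replacement for ad-hoc "is this a valid JAX input?" checks.
is_valid_array = jax.core.valid_jaxtype(jnp.ones(3))
is_valid_str = jax.core.valid_jaxtype("not an array")
```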
JAX v0.6.2
- New features:
  - Added `jax.tree.broadcast`, which implements a pytree prefix broadcasting
    helper.
- Changes
  - The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12.
JAX v0.6.1
- New features:
  - Added `jax.lax.axis_size`, which returns the size of the mapped axis given
    its name.
- Changes
  - Additional checking for the versions of CUDA package dependencies was
    re-enabled, having been accidentally disabled in a previous release.
  - JAX nightly packages are now published to Artifact Registry. To install
    these packages, see the JAX installation guide.
  - `jax.sharding.PartitionSpec` no longer inherits from a tuple.
  - `jax.ShapeDtypeStruct` is now immutable. Please use the `.update` method
    to update your `ShapeDtypeStruct` instead of doing in-place updates.
- Deprecations
  - `jax.custom_derivatives.custom_jvp_call_jaxpr_p` is deprecated, and will be
    removed in JAX v0.7.0.
JAX v0.6.0
- Breaking changes
  - `jax.numpy.array` no longer accepts `None`. This behavior has been
    deprecated since November 2023 and is now removed.
  - Removed the `config.jax_data_dependent_tracing_fallback` config option,
    which was added temporarily in v0.4.36 to allow users to opt out of the
    new "stackless" tracing machinery.
  - Removed the `config.jax_eager_pmap` config option.
  - Disallow the calling of `lower` and `trace` AOT APIs on the result of
    `jax.jit` if there have been subsequent wrappers applied. Previously this
    worked, but silently ignored the wrappers. The workaround is to apply
    `jax.jit` last among the wrappers, and similarly for `jax.pmap`. See
    #27873.
  - The `cuda12_pip` extra for `jax` has been removed; use
    `pip install jax[cuda12]` instead.
- Changes
  - The minimum CuDNN version is v9.8.
  - JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain
    supported.
  - JAX package extras are now updated to use dash instead of underscore to
    align with PEP 685. For instance, if you were previously using
    `pip install jax[cuda12_local]` to install JAX, run
    `pip install jax[cuda12-local]` instead.
  - `jax.jit` now requires `fun` to be passed by position, and additional
    arguments to be passed by keyword. Doing otherwise will result in a
    `DeprecationWarning` in v0.6.X, and an error starting in v0.7.X.
- Deprecations
  - `jax.tree_util.build_tree` is deprecated. Use `jax.tree.unflatten`
    instead.
  - Implemented host callback handlers for CPU and GPU devices using XLA's FFI
    and removed existing CPU/GPU handlers using XLA's custom call.
  - All APIs in `jax.lib.xla_extension` are now deprecated.
  - `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`,
    which were accidental exports, have been removed. If needed, they are
    available from `jax.extend.mlir`.
  - `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by
    `jax.ffi` should be used instead.
  - The deprecated use of `jax.ffi.ffi_call` with inline arguments is no
    longer supported. `jax.ffi.ffi_call` now unconditionally returns a
    callable.
  - The following exports in `jax.lib.xla_client` are deprecated:
    `get_topology_for_devices`, `heap_profile`, `mlir_api_version`, `Client`,
    `CompileOptions`, `DeviceAssignment`, `Frame`, `HloSharding`,
    `OpSharding`, `Traceback`.
  - The following internal APIs in `jax.util` are deprecated:
    `HashableFunction`, `as_hashable_function`, `cache`, `safe_map`,
    `safe_zip`, `split_dict`, `split_list`, `split_list_checked`,
    `split_merge`, `subvals`, `toposort`, `unzip2`, `wrap_name`, and `wraps`.
  - `jax.dlpack.to_dlpack` has been deprecated. You can usually pass a JAX
    `Array` directly to the `from_dlpack` function of another framework. If you
    need the functionality of `to_dlpack`, use the `__dlpack__` attribute of an
    array.
  - `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, and
    `jax.lax.outfeed_p` are deprecated and will be removed in JAX v0.7.0.
  - Several previously-deprecated APIs have been removed, including:
    - From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`,
      `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`, `ops`,
      `register_custom_call_target`, `shape_from_pyval`, `Shape`,
      `XlaComputation`.
    - From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`.
    - From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`,
      `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and
      `jax.tree_unflatten`. Replacements can be found in `jax.tree` or
      `jax.tree_util`.
    - From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`,
      `InputType`, `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`,
      `OpaqueTraceState`, `OutDBIdx`, `Primitive`, `Token`,
      `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`,
      `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`,
      `full_lower`, `get_referent`, `jaxpr_as_fun`, `join_effects`,
      `lattice_join`, `leaked_tracer_error`, `maybe_find_leaked_tracers`,
      `raise_to_shaped`, `raise_to_shaped_mappings`, `reset_trace_state`,
      `str_eqn_compact`, `substitute_vars_in_output_ty`, `typecompat`, and
      `used_axis_names_jaxpr`. Most have no public replacement, though a few
      are available at `jax.extend.core`.
    - The `vectorized` argument to `jax.pure_callback` and
      `jax.ffi.ffi_call`. Use the `vmap_method` parameter instead.
JAX v0.5.3
- New Features
  - Added an `allow_negative_indices` option to `jax.lax.dynamic_slice`,
    `jax.lax.dynamic_update_slice` and related functions. The default is
    true, matching the current behavior. If set to false, JAX does not need to
    emit code clamping negative indices, which improves code size.
  - Added a `replace` option to `jax.random.categorical` to enable sampling
    without replacement.
JAX v0.5.2
Patch release of 0.5.1
- Bug fixes
  - Fixes TPU metric logging and `tpu-info`, which were broken in 0.5.1.
JAX v0.5.1
- New Features
  - Added an experimental `jax.experimental.custom_dce.custom_dce` decorator
    to support customizing the behavior of opaque functions under JAX-level
    dead code elimination (DCE). See #25956 for more details.
  - Added low-level reduction APIs in `jax.lax`: `jax.lax.reduce_sum`,
    `jax.lax.reduce_prod`, `jax.lax.reduce_max`, `jax.lax.reduce_min`,
    `jax.lax.reduce_and`, `jax.lax.reduce_or`, and `jax.lax.reduce_xor`.
  - `jax.lax.linalg.qr` and `jax.scipy.linalg.qr` now support column-pivoting
    on CPU and GPU. See #20282 and #25955 for more details.
- Changes
  - `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
    environment variables. Before, they could only be specified via
    `jax.config` or flags.
  - `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults to `'gloo'`, meaning
    multi-process CPU communication works out of the box.
  - The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly`
    package. This package may safely be removed if it is present on your
    machine; JAX now uses `libtpu` instead.
- Deprecations
  - The internal function `linear_util.wrap_init` and the constructor
    `core.Jaxpr` now must take a non-empty `core.DebugInfo` kwarg. For a
    limited time, a `DeprecationWarning` is printed if
    `jax.extend.linear_util.wrap_init` is used without debugging info. A
    downstream effect of this is that several other internal functions need
    debug info. This change does not affect public APIs. See #26480 for more
    detail.
- Bug fixes
  - TPU runtime startup and shutdown time should be significantly improved on
    TPU v5e and newer (from around 17s to around 8s). If not already set, you
    may need to enable transparent hugepages in your VM image
    (`sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'`).
    We hope to improve this further in future releases.
  - The persistent compilation cache no longer writes an access-time file if
    `JAX_COMPILATION_CACHE_MAX_SIZE` is unset or set to -1, i.e. if the LRU
    eviction policy isn't enabled. This should improve performance when using
    the cache with large-scale network storage.