 from __future__ import annotations
 
 import enum
-from typing import Any
+from typing import Any, Optional
 import warnings
 
+from jax._src.api import device_put
 from jax import numpy as jnp
 from jax._src import array
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension_version
 from jax._src.typing import Array
-
+from jax._src.sharding import Sharding
 
 # A set of dtypes that dlpack supports.
 # Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -82,16 +83,108 @@ def to_dlpack(x: Array, take_ownership: bool = False,
       x.addressable_data(0), stream=stream
   )  # type: ignore
 
+def _place_array(_arr, device, dlpack_device, copy):
+  if device and dlpack_device != device:
+    if copy is not None and not copy:
+      raise ValueError(
+          f"Specified {device=} which requires a copy since the source device "
+          f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
+          "copy=None to perform the requested operation."
+      )
+    else:
+      return device_put(_arr, device)
+  if copy:
+    return jnp.array(_arr, copy=True)
+  return _arr
+
+def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None, copy: Optional[bool] = None):
+  preferred_platform = getattr(device, "platform", None)
+  if device and preferred_platform == "gpu":
+    preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm"
+
+  cpu_backend = xla_bridge.get_backend("cpu")
+  gpu_backend = None
+
+  if preferred_platform in {"cuda", "rocm"}:
+    try:
+      gpu_backend = xla_bridge.get_backend(preferred_platform)
+    except RuntimeError:
+      raise TypeError(
+          f"A {str.upper(preferred_platform)} device was specified, however no "
+          f"{str.upper(preferred_platform)} backend was found."
+      )
 
-def from_dlpack(external_array):
+  if preferred_platform is None:
+    try:
+      gpu_backend = xla_bridge.get_backend("cuda")
+    except RuntimeError:
+      pass
+    # Try ROCm if CUDA backend not found
+    if gpu_backend is None:
+      try:
+        gpu_backend = xla_bridge.get_backend("rocm")
+      except RuntimeError:
+        pass
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, cpu_backend, gpu_backend))
+
+  return _place_array(_arr, device, _arr.devices().pop(), copy)
+
+def _from_dlpack(external_array, device: xla_client.Device | None = None, copy: bool | None = None):
+  dl_device_type, device_id = external_array.__dlpack_device__()
+  try:
+    dl_device_platform = {
+        DLDeviceType.kDLCPU: "cpu",
+        DLDeviceType.kDLCUDA: "cuda",
+        DLDeviceType.kDLROCM: "rocm",
+    }[dl_device_type]
+  except TypeError:
+    # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
+    # TypeError.
+    raise TypeError(
+        "Array passed to from_dlpack is on unsupported device type "
+        f"(DLDeviceType: {dl_device_type}, array: {external_array}")
+
+  backend = xla_bridge.get_backend(dl_device_platform)
+  dlpack_device = backend.device_from_local_hardware_id(device_id)
+  try:
+    stream = dlpack_device.get_stream_for_external_ready_events()
+  except xla_client.XlaRuntimeError as err:  # type: ignore
+    if "UNIMPLEMENTED" in str(err):
+      stream = None
+    else:
+      raise
+  dlpack = external_array.__dlpack__(stream=stream)
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, dlpack_device, stream))
+
+  return _place_array(_arr, device, dlpack_device, copy)
+
+def from_dlpack(external_array, device: xla_client.Device | Sharding | None = None, copy: bool | None = None):
87166 """Returns a :class:`~jax.Array` representation of a DLPack tensor.
88167
89- The returned :class:`~jax.Array` shares memory with ``external_array``.
168+ The returned :class:`~jax.Array` shares memory with ``external_array`` if no
169+ device transfer or copy was requested.
90170
91171 Args:
92- external_array: an array object that has __dlpack__ and __dlpack_device__
172+ external_array: An array object that has __dlpack__ and __dlpack_device__
93173 methods, or a DLPack tensor on either CPU or GPU (legacy API).
94174
175+ device: The (optional) :py:class:`Device`, representing the device on which
176+ the returned array should be placed. If given, then the result is committed
177+ to the device. If unspecified, the resulting array will be unpacked onto the
178+ same device it originated from. Setting ``device`` to a device different from
179+ the source of ``external_array`` will require a copy, meaning ``copy`` must be
180+ set to either ``True`` or ``None``.
181+
+    copy: An (optional) boolean, controlling whether or not a copy is performed.
+      If ``copy=True`` then a copy is always performed, even if unpacked onto the
+      same device. If ``copy=False`` then a copy is never performed and an error
+      is raised if one would be required. If ``copy=None`` then a copy may be
+      performed if needed for a device transfer.
+
   Returns:
     A jax.Array
 
@@ -102,49 +195,16 @@ def from_dlpack(external_array):
   is later modified in-place, it may lead to undefined behavior when using
   the associated JAX array.
   """
+  if isinstance(device, Sharding):
+    device_set = device.device_set
+    if len(device_set) > 1:
+      raise ValueError(
+          "from_dlpack can only unpack a dlpack tensor onto a singular device, but "
+          f"a Sharding with {len(device_set)} devices was provided."
+      )
+    device = device_set.pop()
   if hasattr(external_array, "__dlpack__"):
-    dl_device_type, device_id = external_array.__dlpack_device__()
-    try:
-      device_platform = {
-          DLDeviceType.kDLCPU: "cpu",
-          DLDeviceType.kDLCUDA: "cuda",
-          DLDeviceType.kDLROCM: "rocm",
-      }[dl_device_type]
-    except TypeError:
-      # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
-      # TypeError.
-      raise TypeError(
-          "Array passed to from_dlpack is on unsupported device type "
-          f"(DLDeviceType: {dl_device_type}, array: {external_array}")
-
-    backend = xla_bridge.get_backend(device_platform)
-    device = backend.device_from_local_hardware_id(device_id)
-    try:
-      stream = device.get_stream_for_external_ready_events()
-    except xla_client.XlaRuntimeError as err:  # type: ignore
-      if "UNIMPLEMENTED" in str(err):
-        stream = None
-      else:
-        raise
-    dlpack = external_array.__dlpack__(stream=stream)
-
-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, device, stream))
-  else:
-    # Legacy path
-    dlpack = external_array
-    cpu_backend = xla_bridge.get_backend("cpu")
-    try:
-      gpu_backend = xla_bridge.get_backend("cuda")
-    except RuntimeError:
-      gpu_backend = None
-
-    # Try ROCm if CUDA backend not found
-    if gpu_backend is None:
-      try:
-        gpu_backend = xla_bridge.get_backend("rocm")
-      except RuntimeError:
-        gpu_backend = None
+    return _from_dlpack(external_array, device, copy)
 
-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, cpu_backend, gpu_backend))
+  # Legacy path
+  return _legacy_from_dlpack(external_array, device, copy)
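
For anyone trying out the new signature, here is a minimal usage sketch. It is not part of this diff: it assumes a CPU-only JAX installation and a NumPy version whose arrays implement `__dlpack__` / `__dlpack_device__`; the names `x`, `y`, and `cpu0` are illustrative only.

```python
# Illustrative sketch of the new `device` / `copy` arguments; assumes a
# CPU-only JAX install with NumPy arrays as the DLPack producer.
import numpy as np
import jax
from jax.dlpack import from_dlpack

x = np.arange(8, dtype=np.float32)

# Default: unpack onto the producing device; copy=None copies only if needed.
y = from_dlpack(x)

# Force a copy so the result never aliases the NumPy buffer.
y_copied = from_dlpack(x, copy=True)

# Commit the result to an explicit device. If this required a cross-device
# transfer while copy=False, _place_array would raise a ValueError instead.
cpu0 = jax.devices("cpu")[0]
y_cpu = from_dlpack(x, device=cpu0)

print(y.devices(), y_copied.devices(), y_cpu.devices())
```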
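The `device` argument also accepts a `Sharding`, which the new code reduces to its single device (a multi-device Sharding is rejected with a `ValueError`). A hypothetical sketch using `SingleDeviceSharding`, again assuming a CPU-only setup:

```python
# Sketch only: passing a single-device Sharding as `device`.
import numpy as np
import jax
from jax.dlpack import from_dlpack
from jax.sharding import SingleDeviceSharding

sharding = SingleDeviceSharding(jax.devices("cpu")[0])
z = from_dlpack(np.ones(4, dtype=np.float32), device=sharding)
print(z.devices())  # expected: the single device wrapped by the sharding
```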