File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -222,7 +222,6 @@ def asarray(data, xp=np):
222
222
223
223
def as_shared_dtype (scalars_or_arrays , xp = np ):
224
224
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
225
- array_type_cupy = array_type ("cupy" )
226
225
if any (is_extension_array_dtype (x ) for x in scalars_or_arrays ):
227
226
extension_array_types = [
228
227
x .dtype for x in scalars_or_arrays if is_extension_array_dtype (x )
@@ -234,7 +233,10 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
234
233
raise ValueError (
235
234
f"Cannot cast arrays to shared type, found array types { [x .dtype for x in scalars_or_arrays ]} "
236
235
)
237
- elif any (isinstance (x , array_type_cupy ) for x in scalars_or_arrays ):
236
+
237
+ # Avoid calling array_type("cupy") repeatidely in the any check
238
+ array_type_cupy = array_type ("cupy" )
239
+ if any (isinstance (x , array_type_cupy ) for x in scalars_or_arrays ):
238
240
import cupy as cp
239
241
240
242
arrays = [asarray (x , xp = cp ) for x in scalars_or_arrays ]
You can’t perform that action at this time.
0 commit comments