Skip to content

Commit e5b46b5

Browse files
authored
Update duck_array_ops.py
1 parent 6ac5cd5 commit e5b46b5

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

xarray/core/duck_array_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,6 @@ def asarray(data, xp=np):
222222

223223
def as_shared_dtype(scalars_or_arrays, xp=np):
224224
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
225-
array_type_cupy = array_type("cupy")
226225
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
227226
extension_array_types = [
228227
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
@@ -234,7 +233,10 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
234233
raise ValueError(
235234
f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}"
236235
)
237-
elif any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
236+
237+
# Avoid calling array_type("cupy") repeatidely in the any check
238+
array_type_cupy = array_type("cupy")
239+
if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
238240
import cupy as cp
239241

240242
arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]

0 commit comments

Comments
 (0)