@@ -445,6 +445,8 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray:
445
445
elif isinstance (values , ABCMultiIndex ):
446
446
# Avoid raising in extract_array
447
447
values = np .array (values )
448
+ else :
449
+ values = extract_array (values , extract_numpy = True )
448
450
449
451
comps = _ensure_arraylike (comps )
450
452
comps = extract_array (comps , extract_numpy = True )
@@ -459,11 +461,14 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray:
459
461
elif needs_i8_conversion (values .dtype ) and not is_object_dtype (comps .dtype ):
460
462
# e.g. comps are integers and values are datetime64s
461
463
return np .zeros (comps .shape , dtype = bool )
464
+ # TODO: not quite right ... Sparse/Categorical
465
+ elif needs_i8_conversion (values .dtype ):
466
+ return isin (comps , values .astype (object ))
462
467
463
- comps , dtype = _ensure_data ( comps )
464
- values , _ = _ensure_data ( values , dtype = dtype )
465
-
466
- f = htable . ismember_object
468
+ elif is_extension_array_dtype ( comps . dtype ) or is_extension_array_dtype (
469
+ values . dtype
470
+ ):
471
+ return isin ( np . asarray ( comps ), np . asarray ( values ))
467
472
468
473
# GH16012
469
474
# Ensure np.in1d doesn't get object types or it *may* throw an exception
@@ -476,23 +481,15 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray:
476
481
f = lambda c , v : np .logical_or (np .in1d (c , v ), np .isnan (c ))
477
482
else :
478
483
f = np .in1d
479
- elif is_integer_dtype (comps .dtype ):
480
- try :
481
- values = values .astype ("int64" , copy = False )
482
- comps = comps .astype ("int64" , copy = False )
483
- f = htable .ismember_int64
484
- except (TypeError , ValueError , OverflowError ):
485
- values = values .astype (object )
486
- comps = comps .astype (object )
487
-
488
- elif is_float_dtype (comps .dtype ):
489
- try :
490
- values = values .astype ("float64" , copy = False )
491
- comps = comps .astype ("float64" , copy = False )
492
- f = htable .ismember_float64
493
- except (TypeError , ValueError ):
494
- values = values .astype (object )
495
- comps = comps .astype (object )
484
+
485
+ else :
486
+ common = np .find_common_type ([values .dtype , comps .dtype ], [])
487
+ values = values .astype (common , copy = False )
488
+ comps = comps .astype (common , copy = False )
489
+ name = common .name
490
+ if name == "bool" :
491
+ name = "uint8"
492
+ f = getattr (htable , f"ismember_{ name } " )
496
493
497
494
return f (comps , values )
498
495
0 commit comments