@@ -395,26 +395,23 @@ def _concatenate_chunks(chunks: list[dict[int, ArrayLike]]) -> dict:
395
395
dtype = dtypes .pop ()
396
396
if is_categorical_dtype (dtype ):
397
397
result [name ] = union_categoricals (arrs , sort_categories = False )
398
+ elif isinstance (dtype , ExtensionDtype ):
399
+ # TODO: concat_compat?
400
+ array_type = dtype .construct_array_type ()
401
+ # error: Argument 1 to "_concat_same_type" of "ExtensionArray"
402
+ # has incompatible type "List[Union[ExtensionArray, ndarray]]";
403
+ # expected "Sequence[ExtensionArray]"
404
+ result [name ] = array_type ._concat_same_type (arrs ) # type: ignore[arg-type]
398
405
else :
399
- if isinstance (dtype , ExtensionDtype ):
400
- # TODO: concat_compat?
401
- array_type = dtype .construct_array_type ()
402
- # error: Argument 1 to "_concat_same_type" of "ExtensionArray"
403
- # has incompatible type "List[Union[ExtensionArray, ndarray]]";
404
- # expected "Sequence[ExtensionArray]"
405
- result [name ] = array_type ._concat_same_type (
406
- arrs # type: ignore[arg-type]
407
- )
408
- else :
409
- # error: Argument 1 to "concatenate" has incompatible
410
- # type "List[Union[ExtensionArray, ndarray[Any, Any]]]"
411
- # ; expected "Union[_SupportsArray[dtype[Any]],
412
- # Sequence[_SupportsArray[dtype[Any]]],
413
- # Sequence[Sequence[_SupportsArray[dtype[Any]]]],
414
- # Sequence[Sequence[Sequence[_SupportsArray[dtype[Any]]]]]
415
- # , Sequence[Sequence[Sequence[Sequence[
416
- # _SupportsArray[dtype[Any]]]]]]]"
417
- result [name ] = np .concatenate (arrs ) # type: ignore[arg-type]
406
+ # error: Argument 1 to "concatenate" has incompatible
407
+ # type "List[Union[ExtensionArray, ndarray[Any, Any]]]"
408
+ # ; expected "Union[_SupportsArray[dtype[Any]],
409
+ # Sequence[_SupportsArray[dtype[Any]]],
410
+ # Sequence[Sequence[_SupportsArray[dtype[Any]]]],
411
+ # Sequence[Sequence[Sequence[_SupportsArray[dtype[Any]]]]]
412
+ # , Sequence[Sequence[Sequence[Sequence[
413
+ # _SupportsArray[dtype[Any]]]]]]]"
414
+ result [name ] = np .concatenate (arrs ) # type: ignore[arg-type]
418
415
419
416
if warning_columns :
420
417
warning_names = "," .join (warning_columns )
0 commit comments