@@ -544,6 +544,13 @@ def test_values(self, arr_dt, idx_dt, ndim, values):
544544 dpnp .put_along_axis (dp_a , dp_ai , values , axis )
545545 assert_array_equal (np_a , dp_a )
546546
547+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
548+ @pytest .mark .parametrize ("dt" , [bool , numpy .float32 ])
549+ def test_invalid_indices_dtype (self , xp , dt ):
550+ a = xp .ones ((10 , 10 ))
551+ ind = xp .ones (10 , dtype = dt )
552+ assert_raises (IndexError , xp .put_along_axis , a , ind , 7 , axis = 1 )
553+
547554 @pytest .mark .parametrize ("arr_dt" , get_all_dtypes ())
548555 @pytest .mark .parametrize ("idx_dt" , get_integer_dtypes ())
549556 def test_broadcast (self , arr_dt , idx_dt ):
@@ -673,66 +680,80 @@ def test_argequivalent(self, func, argfunc, kwargs):
673680 @pytest .mark .parametrize ("idx_dt" , get_integer_dtypes ())
674681 @pytest .mark .parametrize ("ndim" , list (range (1 , 4 )))
675682 def test_multi_dimensions (self , arr_dt , idx_dt , ndim ):
676- np_a = numpy .arange (4 ** ndim , dtype = arr_dt ).reshape ((4 ,) * ndim )
677- np_ai = numpy .array ([3 , 0 , 2 , 1 ], dtype = idx_dt ).reshape (
683+ a = numpy .arange (4 ** ndim , dtype = arr_dt ).reshape ((4 ,) * ndim )
684+ ind = numpy .array ([3 , 0 , 2 , 1 ], dtype = idx_dt ).reshape (
678685 (1 ,) * (ndim - 1 ) + (4 ,)
679686 )
680-
681- dp_a = dpnp .array (np_a , dtype = arr_dt )
682- dp_ai = dpnp .array (np_ai , dtype = idx_dt )
687+ ia , iind = dpnp .array (a ), dpnp .array (ind )
683688
684689 for axis in range (ndim ):
685- expected = numpy .take_along_axis (np_a , np_ai , axis )
686- result = dpnp .take_along_axis (dp_a , dp_ai , axis )
690+ result = dpnp .take_along_axis (ia , iind , axis )
691+ expected = numpy .take_along_axis (a , ind , axis )
687692 assert_array_equal (expected , result )
688693
689694 @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
690- def test_invalid (self , xp ):
695+ def test_not_enough_indices (self , xp ):
691696 a = xp .ones ((10 , 10 ))
692- ai = xp .ones ((10 , 2 ), dtype = xp .intp )
693-
694- # not enough indices
695697 assert_raises (ValueError , xp .take_along_axis , a , xp .array (1 ), axis = 1 )
696698
697- # bool arrays not allowed
698- assert_raises (
699- IndexError , xp .take_along_axis , a , ai .astype (bool ), axis = 1
700- )
699+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
700+ @pytest .mark .parametrize ("dt" , [bool , numpy .float32 ])
701+ def test_invalid_indices_dtype (self , xp , dt ):
702+ a = xp .ones ((10 , 10 ))
703+ ind = xp .ones ((10 , 2 ), dtype = dt )
704+ assert_raises (IndexError , xp .take_along_axis , a , ind , axis = 1 )
701705
702- # float arrays not allowed
703- assert_raises (
704- IndexError , xp .take_along_axis , a , ai .astype (numpy .float32 ), axis = 1
705- )
706+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
707+ def test_invalid_axis (self , xp ):
708+ a = xp .ones ((10 , 10 ))
709+ ind = xp .ones ((10 , 2 ), dtype = xp .intp )
710+ assert_raises (AxisError , xp .take_along_axis , a , ind , axis = 10 )
706711
707- # invalid axis
708- assert_raises (AxisError , xp .take_along_axis , a , ai , axis = 10 )
712+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
713+ def test_indices_ndim_axis_none (self , xp ):
714+ a = xp .ones ((10 , 10 ))
715+ ind = xp .ones ((10 , 2 ), dtype = xp .intp )
716+ assert_raises (ValueError , xp .take_along_axis , a , ind , axis = None )
709717
710- @pytest .mark .parametrize ("arr_dt " , get_all_dtypes ())
718+ @pytest .mark .parametrize ("a_dt " , get_all_dtypes (no_none = True ))
711719 @pytest .mark .parametrize ("idx_dt" , get_integer_dtypes ())
712- def test_empty (self , arr_dt , idx_dt ):
713- np_a = numpy .ones ((3 , 4 , 5 ), dtype = arr_dt )
714- np_ai = numpy .ones ((3 , 0 , 5 ), dtype = idx_dt )
715-
716- dp_a = dpnp .array (np_a , dtype = arr_dt )
717- dp_ai = dpnp .array (np_ai , dtype = idx_dt )
720+ def test_empty (self , a_dt , idx_dt ):
721+ a = numpy .ones ((3 , 4 , 5 ), dtype = a_dt )
722+ ind = numpy .ones ((3 , 0 , 5 ), dtype = idx_dt )
723+ ia , iind = dpnp .array (a ), dpnp .array (ind )
718724
719- expected = numpy .take_along_axis (np_a , np_ai , axis = 1 )
720- result = dpnp .take_along_axis (dp_a , dp_ai , axis = 1 )
725+ result = dpnp .take_along_axis (ia , iind , axis = 1 )
726+ expected = numpy .take_along_axis (a , ind , axis = 1 )
721727 assert_array_equal (expected , result )
722728
723- @pytest .mark .parametrize ("arr_dt " , get_all_dtypes ())
729+ @pytest .mark .parametrize ("a_dt " , get_all_dtypes (no_none = True ))
724730 @pytest .mark .parametrize ("idx_dt" , get_integer_dtypes ())
725- def test_broadcast (self , arr_dt , idx_dt ):
726- np_a = numpy .ones ((3 , 4 , 1 ), dtype = arr_dt )
727- np_ai = numpy .ones ((1 , 2 , 5 ), dtype = idx_dt )
728-
729- dp_a = dpnp .array (np_a , dtype = arr_dt )
730- dp_ai = dpnp .array (np_ai , dtype = idx_dt )
731+ def test_broadcast (self , a_dt , idx_dt ):
732+ a = numpy .ones ((3 , 4 , 1 ), dtype = a_dt )
733+ ind = numpy .ones ((1 , 2 , 5 ), dtype = idx_dt )
734+ ia , iind = dpnp .array (a ), dpnp .array (ind )
731735
732- expected = numpy .take_along_axis (np_a , np_ai , axis = 1 )
733- result = dpnp .take_along_axis (dp_a , dp_ai , axis = 1 )
736+ result = dpnp .take_along_axis (ia , iind , axis = 1 )
737+ expected = numpy .take_along_axis (a , ind , axis = 1 )
734738 assert_array_equal (expected , result )
735739
740+ def test_mode_wrap (self ):
741+ a = numpy .array ([- 2 , - 1 , 0 , 1 , 2 ])
742+ ind = numpy .array ([- 2 , 2 , - 5 , 4 ])
743+ ia , iind = dpnp .array (a ), dpnp .array (ind )
744+
745+ result = dpnp .take_along_axis (ia , iind , axis = 0 , mode = "wrap" )
746+ expected = numpy .take_along_axis (a , ind , axis = 0 )
747+ assert_array_equal (result , expected )
748+
749+ def test_mode_clip (self ):
750+ a = dpnp .array ([- 2 , - 1 , 0 , 1 , 2 ])
751+ ind = dpnp .array ([- 2 , 2 , - 5 , 4 ])
752+
753+ # numpy does not support keyword `mode`
754+ result = dpnp .take_along_axis (a , ind , axis = 0 , mode = "clip" )
755+ assert (result == dpnp .array ([- 2 , 0 , - 2 , 2 ])).all ()
756+
736757
737758@pytest .mark .usefixtures ("allow_fall_back_on_numpy" )
738759def test_choose ():
0 commit comments