2121
2222import dpctl
2323import dpctl .tensor as dpt
24- from dpctl .tensor ._type_utils import _can_cast
24+ from dpctl .tensor ._elementwise_common import _get_dtype
25+ from dpctl .tensor ._type_utils import (
26+ _can_cast ,
27+ _strong_dtype_num_kind ,
28+ _weak_type_num_kind ,
29+ )
2530from dpctl .utils import ExecutionPlacementError
2631
2732_all_dtypes = [
@@ -236,6 +241,21 @@ def test_clip_arg_validation():
236241 with pytest .raises (TypeError ):
237242 dpt .clip (check , x1 , x2 )
238243
244+ with pytest .raises (ValueError ):
245+ dpt .clip (x1 , check , x2 )
246+
247+ with pytest .raises (ValueError ):
248+ dpt .clip (x1 , check )
249+
250+ with pytest .raises (TypeError ):
251+ dpt .clip (x1 , x1 , x2 , out = check )
252+
253+ with pytest .raises (TypeError ):
254+ dpt .clip (x1 , x2 , out = check )
255+
256+ with pytest .raises (TypeError ):
257+ dpt .clip (x1 , out = check )
258+
239259
240260@pytest .mark .parametrize (
241261 "dt1,dt2" , [("i4" , "i4" ), ("i4" , "i2" ), ("i2" , "i4" ), ("i1" , "i2" )]
@@ -608,22 +628,40 @@ def test_clip_max_less_than_min():
608628 assert dpt .all (res == 0 )
609629
610630
611- def test_clip_minmax_weak_types ():
631+ @pytest .mark .parametrize ("dt" , ["?" , "i4" , "f4" , "c8" ])
632+ def test_clip_minmax_weak_types (dt ):
612633 get_queue_or_skip ()
613634
614- x = dpt .zeros (10 , dtype = dpt . bool )
635+ x = dpt .zeros (10 , dtype = dt )
615636 min_list = [False , 0 , 0.0 , 0.0 + 0.0j ]
616637 max_list = [True , 1 , 1.0 , 1.0 + 0.0j ]
638+
617639 for min_v , max_v in zip (min_list , max_list ):
618- if isinstance (min_v , bool ) and isinstance (max_v , bool ):
619- y = dpt .clip (x , min_v , max_v )
620- assert isinstance (y , dpt .usm_ndarray )
640+ st_dt = _strong_dtype_num_kind (dpt .dtype (dt ))
641+ wk_dt1 = _weak_type_num_kind (_get_dtype (min_v , x .sycl_device ))
642+ wk_dt2 = _weak_type_num_kind (_get_dtype (max_v , x .sycl_device ))
643+
644+ if st_dt >= wk_dt1 and st_dt >= wk_dt2 :
645+ r = dpt .clip (x , min_v , max_v )
646+ assert isinstance (r , dpt .usm_ndarray )
621647 else :
622648 with pytest .raises (ValueError ):
623649 dpt .clip (x , min_v , max_v )
624650
651+ if st_dt >= wk_dt1 :
652+ r = dpt .clip (x , min_v )
653+ assert isinstance (r , dpt .usm_ndarray )
654+
655+ r = dpt .clip (x , None , min_v )
656+ assert isinstance (r , dpt .usm_ndarray )
657+ else :
658+ with pytest .raises (ValueError ):
659+ dpt .clip (x , min_v )
660+ with pytest .raises (ValueError ):
661+ dpt .clip (x , None , max_v )
662+
625663
626- def test_clip_max_weak_types ():
664+ def test_clip_max_weak_type_errors ():
627665 get_queue_or_skip ()
628666
629667 x = dpt .zeros (10 , dtype = "i4" )
@@ -635,6 +673,15 @@ def test_clip_max_weak_types():
635673 with pytest .raises (ValueError ):
636674 dpt .clip (x , 2.5 , m )
637675
676+ with pytest .raises (ValueError ):
677+ dpt .clip (x , 2.5 )
678+
679+ with pytest .raises (ValueError ):
680+ dpt .clip (dpt .astype (x , "?" ), 2 )
681+
682+ with pytest .raises (ValueError ):
683+ dpt .clip (dpt .astype (x , "f4" ), complex (2 ))
684+
638685
639686def test_clip_unaligned ():
640687 get_queue_or_skip ()
@@ -653,3 +700,51 @@ def test_clip_none_args():
653700 x = dpt .arange (10 , dtype = "i4" )
654701 r = dpt .clip (x )
655702 assert dpt .all (x == r )
703+
704+
705+ def test_clip_shape_errors ():
706+ get_queue_or_skip ()
707+
708+ x = dpt .ones ((4 , 4 ), dtype = "i4" )
709+ a_min = dpt .ones (5 , dtype = "i4" )
710+ a_max = dpt .ones (5 , dtype = "i4" )
711+
712+ with pytest .raises (ValueError ):
713+ dpt .clip (x , a_min , a_max )
714+
715+ with pytest .raises (ValueError ):
716+ dpt .clip (x , a_min )
717+
718+ with pytest .raises (ValueError ):
719+ dpt .clip (x , 0 , 1 , out = a_min )
720+
721+ with pytest .raises (ValueError ):
722+ dpt .clip (x , 0 , out = a_min )
723+
724+ with pytest .raises (ValueError ):
725+ dpt .clip (x , out = a_min )
726+
727+
728+ def test_clip_compute_follows_data ():
729+ q1 = get_queue_or_skip ()
730+ q2 = get_queue_or_skip ()
731+
732+ x = dpt .ones (10 , dtype = "i4" , sycl_queue = q1 )
733+ a_min = dpt .ones (10 , dtype = "i4" , sycl_queue = q2 )
734+ a_max = dpt .ones (10 , dtype = "i4" , sycl_queue = q1 )
735+ res = dpt .empty_like (x , sycl_queue = q2 )
736+
737+ with pytest .raises (ExecutionPlacementError ):
738+ dpt .clip (x , a_min , a_max )
739+
740+ with pytest .raises (ExecutionPlacementError ):
741+ dpt .clip (x , dpt .ones_like (x ), a_max , out = res )
742+
743+ with pytest .raises (ExecutionPlacementError ):
744+ dpt .clip (x , a_min )
745+
746+ with pytest .raises (ExecutionPlacementError ):
747+ dpt .clip (x , None , a_max , out = res )
748+
749+ with pytest .raises (ExecutionPlacementError ):
750+ dpt .clip (x , out = res )
0 commit comments