1515# limitations under the License.
1616
1717
18- from itertools import chain , product , repeat
18+ import operator
19+ from itertools import chain , repeat
1920
2021import numpy as np
2122from numpy .core .numeric import normalize_axis_index , normalize_axis_tuple
@@ -426,6 +427,7 @@ def roll(X, shift, axis=None):
426427 if not isinstance (X , dpt .usm_ndarray ):
427428 raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
428429 if axis is None :
430+ shift = operator .index (shift )
429431 res = dpt .empty (
430432 X .shape , dtype = X .dtype , usm_type = X .usm_type , sycl_queue = X .sycl_queue
431433 )
@@ -438,31 +440,20 @@ def roll(X, shift, axis=None):
438440 broadcasted = np .broadcast (shift , axis )
439441 if broadcasted .ndim > 1 :
440442 raise ValueError ("'shift' and 'axis' should be scalars or 1D sequences" )
441- shifts = {ax : 0 for ax in range (X .ndim )}
443+ shifts = [
444+ 0 ,
445+ ] * X .ndim
442446 for sh , ax in broadcasted :
443447 shifts [ax ] += sh
444- rolls = [((np .s_ [:], np .s_ [:]),)] * X .ndim
445- for ax , offset in shifts .items ():
446- offset %= X .shape [ax ] or 1
447- if offset :
448- # (original, result), (original, result)
449- rolls [ax ] = (
450- (np .s_ [:- offset ], np .s_ [offset :]),
451- (np .s_ [- offset :], np .s_ [:offset ]),
452- )
453448
449+ exec_q = X .sycl_queue
454450 res = dpt .empty (
455- X .shape , dtype = X .dtype , usm_type = X .usm_type , sycl_queue = X . sycl_queue
451+ X .shape , dtype = X .dtype , usm_type = X .usm_type , sycl_queue = exec_q
456452 )
457- hev_list = []
458- for indices in product (* rolls ):
459- arr_index , res_index = zip (* indices )
460- hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
461- src = X [arr_index ], dst = res [res_index ], sycl_queue = X .sycl_queue
462- )
463- hev_list .append (hev )
464-
465- dpctl .SyclEvent .wait_for (hev_list )
453+ ht_e , _ = ti ._copy_usm_ndarray_for_roll_nd (
454+ src = X , dst = res , shifts = shifts , sycl_queue = exec_q
455+ )
456+ ht_e .wait ()
466457 return res
467458
468459
0 commit comments