@@ -1365,6 +1365,65 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs)
        yield opinfo_core.SampleInput(input_, args=(src, *args))


def sample_inputs_scatter_src(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del kwargs
    make_arg = functools.partial(
        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )

    # Basic test cases for scatter.src
    cases = [
        # (self_shape, index_shape, src_shape, dim)
        ((5, 5), (2, 3), (2, 3), 0),  # 2D scatter on dim=0
        ((5, 5), (3, 2), (3, 2), 1),  # 2D scatter on dim=1
        ((3, 4, 5), (2, 2, 3), (2, 2, 3), 0),  # 3D scatter on dim=0
        ((3, 4, 5), (2, 2, 3), (2, 2, 3), 1),  # 3D scatter on dim=1
        ((3, 4, 5), (2, 2, 3), (2, 2, 3), 2),  # 3D scatter on dim=2
        ((10,), (3,), (3,), 0),  # 1D scatter
    ]

    for self_shape, index_shape, src_shape, dim in cases:
        self_tensor = make_arg(self_shape)
        # Create valid indices for the given dimension without duplication
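        # (argsort over uniform noise yields a random permutation of
        # 0..self_shape[dim]-1 along `dim`; slicing it down to `index_shape`
        # preserves uniqueness along `dim`, so scatter never writes the same
        # element twice)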
        index_buffer_shape = list(index_shape)
        index_buffer_shape[dim] = self_shape[dim]
        index_tensor = torch.rand(index_buffer_shape, device=device).argsort(dim=dim)[
            tuple(slice(None, d, None) for d in index_shape)
        ]
        src_tensor = make_arg(src_shape)
        yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, src_tensor))

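# A minimal REPL sketch of the permutation trick used above; the shapes here
# are assumptions for illustration, not drawn from the cases list:
#
#     >>> perm = torch.rand(5, 3).argsort(dim=0)  # each column is a permutation of 0..4
#     >>> index = perm[:2, :]                     # sliced down: still unique along dim 0
#     >>> torch.zeros(5, 3).scatter(0, index, torch.ones(2, 3))
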
def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del kwargs
    make_arg = functools.partial(
        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )

    # Basic test cases for scatter.value
    cases = [
        # (self_shape, index_shape, dim, value)
        ((5, 5), (2, 3), 0, 1.0),  # 2D scatter on dim=0 with scalar value
        ((5, 5), (3, 2), 1, -2.5),  # 2D scatter on dim=1 with scalar value
        ((3, 4, 5), (2, 2, 3), 0, 0.0),  # 3D scatter on dim=0 with scalar value
        ((3, 4, 5), (2, 2, 3), 1, 3.14),  # 3D scatter on dim=1 with scalar value
        ((3, 4, 5), (2, 2, 3), 2, -1.0),  # 3D scatter on dim=2 with scalar value
        ((10,), (3,), 0, 5.0),  # 1D scatter with scalar value
    ]

    for self_shape, index_shape, dim, value in cases:
        self_tensor = make_arg(self_shape)
        # Create indices that are unique along `dim` (same permutation trick
        # as in sample_inputs_scatter_src above)
        index_buffer_shape = list(index_shape)
        index_buffer_shape[dim] = self_shape[dim]
        index_tensor = torch.rand(index_buffer_shape, device=device).argsort(dim=dim)[
            tuple(slice(None, d, None) for d in index_shape)
        ]
        yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, value))

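# Each yielded SampleInput maps positionally onto the aten overload; a hedged
# smoke-test sketch (the device and dtype here are assumptions):
#
#     >>> sample = next(sample_inputs_scatter_value(None, "cpu", torch.float32, False))
#     >>> torch.ops.aten.scatter.value(sample.input, *sample.args)
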
def sample_inputs__scaled_dot_product_flash_attention(
    op_info, device, dtype, requires_grad, **kwargs
):
@@ -2533,6 +2592,22 @@ def __init__(self):
        sample_inputs_func=sample_inputs_slice_scatter,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.scatter.src",
        op=torch.ops.aten.scatter.src,
        aten_name="scatter.src",
        dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool),
        sample_inputs_func=sample_inputs_scatter_src,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.scatter.value",
        op=torch.ops.aten.scatter.value,
        aten_name="scatter.value",
        dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool),
        sample_inputs_func=sample_inputs_scatter_value,
        supports_out=False,
    ),
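    # For the two scatter entries above, all_types_and(...) expands to the
    # standard integral and floating dtypes plus bfloat16, half, and bool;
    # complex dtypes are not exercised.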
    opinfo_core.OpInfo(
        "ops.aten._softmax",
        op=torch.ops.aten._softmax,  # pylint: disable=protected-access