
Unroll loop for scatter_gather #1736

Open · wants to merge 3 commits into main

2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/ScatterGatherKernels.cpp
@@ -153,6 +153,8 @@ struct ScatterGatherElementwiseKernelFunctor {
     auto wg_id = item.get_group_linear_id();
     auto local_id = item.get_local_linear_id();
     int idx = nv * wg_id + local_id;
+
+#pragma unroll
     for (int i = 0; i < thread_work_size_; ++i) {
       if (idx < N_) {
         f_(idx);
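
Whether `#pragma unroll` fully unrolls this loop depends on the trip count being visible to the compiler: if `thread_work_size_` is a compile-time constant (e.g. a template parameter), the body gets replicated and the loop-counter branch disappears; if it is a runtime value, the pragma is only a hint. A minimal sketch of the pattern with a compile-time trip count (plain C++ with Clang/DPC++-style `#pragma unroll`; the names are illustrative, not the kernel's actual ones):

```cpp
// Illustrative sketch, not the kernel's actual code: with a compile-time
// trip count, #pragma unroll replicates the body kWorkSize times and
// keeps per-element bounds checks instead of a loop-exit branch.
template <int kWorkSize, typename Func>
void apply_per_thread(int base_idx, int stride, int n, Func f) {
#pragma unroll
  for (int i = 0; i < kWorkSize; ++i) {
    int idx = base_idx + i * stride; // hypothetical index stepping
    if (idx < n) { // tail elements are masked here, not by exiting the loop
      f(idx);
    }
  }
}
```

For example, `apply_per_thread<4>(idx0, wg_size, n, f)` compiles to four predicated copies of the body; whether the kernel above gets the same full unroll depends on how `thread_work_size_` is defined, which the visible diff does not show.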
92 changes: 92 additions & 0 deletions test/microbench/scatter.gather.py
@@ -0,0 +1,92 @@
import torch
from torch.profiler import profile, ProfilerActivity

device = "xpu"
backward = True

# Define shapes for scatter/gather testing
# (input_shape, index_shape, dim_to_scatter_gather)
shape_list = [
((4096, 8192), (4096, 8192), 1), # Simple 2D case
((2, 4096, 320), (2, 4096, 1), 2), # Scatter/Gather along the last dim
((512, 3136, 128), (512, 1, 128), 1), # Scatter/Gather along the middle dim
((128, 49, 196, 1024), (128, 49, 196, 1), 3), # 4D case, scatter/gather last dim
]

for shape_config in shape_list:
input_shape, index_shape, dim_to_operate = shape_config

for dtype in [torch.bfloat16, torch.float16, torch.float32]:
# Generate input tensor
input_tensor = torch.randn(input_shape, device=device, dtype=dtype)

# Generate index tensor for gather/scatter
# Ensure indices are within valid bounds for the dimension
max_idx_val = input_tensor.shape[dim_to_operate]
index_tensor = torch.randint(0, max_idx_val, index_shape, device=device, dtype=torch.int64)

        # Generate source tensor for scatter.
        # scatter_ requires src to be at least as large as index in every dim,
        # so take input's shape and shrink the scattered dim to match index.
        scatter_source_shape = list(input_tensor.shape)
        scatter_source_shape[dim_to_operate] = index_shape[dim_to_operate]
        scatter_source = torch.randn(scatter_source_shape, device=device, dtype=dtype)

if backward:
input_tensor.requires_grad_(True)
scatter_source.requires_grad_(True)

# Warm-up phase
# Gather operation
gathered_output_warmup = torch.gather(input_tensor, dim_to_operate, index_tensor)
if backward:
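            # empty_like leaves gradient values uninitialized; only shape and
            # dtype matter when timing backward, so this is fine for a microbench.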
gy_gather = torch.empty_like(gathered_output_warmup)
gathered_output_warmup.backward(gy_gather)

        # Scatter operation (clone, then in-place scatter_, so each call starts from a fresh tensor)
scattered_output_warmup = input_tensor.clone().scatter_(dim_to_operate, index_tensor, scatter_source)
if backward:
gy_scatter = torch.empty_like(scattered_output_warmup)
scattered_output_warmup.backward(gy_scatter)

        print("---")
print(
"Testing Scatter/Gather -- input shape:",
input_shape,
"; index shape:",
index_shape,
"; datatype:",
dtype,
"; dim:",
dim_to_operate,
"; backward:",
backward,
)
        print("---")

# Profiling phase
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], record_shapes=True
) as prof:
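            # 20 timed iterations; backward() accumulates into .grad across
            # iterations, which is acceptable for throughput profiling.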
for i in range(20):
# Gather operation
gathered_output = torch.gather(input_tensor, dim_to_operate, index_tensor)
if backward:
gy_gather = torch.empty_like(gathered_output)
gathered_output.backward(gy_gather)

                # Scatter operation
                # scatter_ is in-place, so clone input_tensor each iteration to
                # avoid compounding mutations of the same tensor across runs.
scattered_output = input_tensor.clone().scatter_(dim_to_operate, index_tensor, scatter_source)
if backward:
gy_scatter = torch.empty_like(scattered_output)
scattered_output.backward(gy_scatter)

print(prof.key_averages().table(sort_by="xpu_time_total"))
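
Assuming a PyTorch build with XPU support, the benchmark runs with `python test/microbench/scatter.gather.py`; the printed table is sorted by total XPU time, so the scatter/gather kernels touched by this change should surface near the top.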