Skip to content

Commit cc9f8d6

Browse files
v0i0markc-614
authored and committed
1 parent 0ae36cb commit cc9f8d6

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

test/inductor/test_torchinductor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13697,6 +13697,28 @@ def f(a_list):
1369713697
print(profile_output)
1369813698
self.assertFalse("Pageable" in profile_output)
1369913699

13700+
@unittest.skipIf(
    config.cpp_wrapper,
    "cpp_wrapper samples will lead to invalid indexing",
)
def test_inductor_triton_bucketize_respects_masking(self):
    """Check that a searchsorted + index_select gather compiles correctly.

    ``fn`` stands in for
    ``torch.repeat_interleave(inp, repeats, dim=0, output_size=output_size)``:
    each output position is looked up in the running total of ``repeats``
    to find the input row it should copy.  The eager result is the
    reference that the ``torch.compile`` result is compared against.
    """

    def fn(inp, repeats, output_size):
        # Running totals of the repeat counts act as the sorted boundaries.
        boundaries = repeats.cumsum(0)
        positions = torch.arange(0, output_size, device=repeats.device)
        # right=True sends a position equal to a boundary to the next row.
        source_rows = torch.searchsorted(boundaries, positions, right=True)
        return torch.index_select(inp, 0, source_rows)

    values = torch.arange(0, 4, device=self.device)
    counts = torch.tensor([1, 2, 3, 4], device=self.device)
    total = counts.sum().item()
    args = (values, counts, total)

    expected = fn(*args)
    actual = torch.compile(fn)(*args)
    self.assertEqual(expected, actual)
13719+
13720+
# end of class CommonTemplate - add new tests here
13721+
1370013722

1370113723
@dataclasses.dataclass
1370213724
class TestFailure:

torch/_inductor/codegen/triton.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2669,6 +2669,18 @@ def guard_cooperative_store(self, name, buffer):
26692669
buffer.writeline(DeferredLine(name, f"if rsplit_id == ({idx} % RSPLIT):"))
26702670
return buffer.indent()
26712671

2672+
def _combine_masks(self, *variables: Optional[CSEVariable]):
2673+
masks = None
2674+
for elem in variables:
2675+
if elem is None:
2676+
continue
2677+
if hasattr(elem, "mask_vars"):
2678+
if masks is None:
2679+
masks = elem.mask_vars
2680+
else:
2681+
masks = masks | elem.mask_vars
2682+
return masks
2683+
26722684
def bucketize(
26732685
self,
26742686
values: CSEVariable,
@@ -2718,6 +2730,9 @@ def bucketize(
27182730
dtype=indexing_dtype, # type: ignore[attr-defined]
27192731
)
27202732

2733+
masks = self._combine_masks(values, boundary_indices, sorter_indices)
2734+
result.mask_vars = masks # type: ignore[attr-defined]
2735+
27212736
return result
27222737

27232738
def reduction_resize(self, value) -> str:

0 commit comments

Comments
 (0)