@@ -1014,6 +1014,150 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_
10141014 _launcher(_helion_jagged_mean_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
10151015 return out
10161016
1017+ --- assertExpectedJournal(TestExamples.test_jagged_softmax)
1018+ from __future__ import annotations
1019+
1020+ import torch
1021+ import triton
1022+ import triton.language as tl
1023+ from torch._inductor.runtime import triton_helpers
1024+ from torch._inductor.runtime.triton_helpers import math as tl_math
1025+ from helion.runtime import default_launcher as _default_launcher
1026+
1027+ @triton.jit
1028+ def _helion_jagged_softmax_kernel(x_offsets, x_flat, out, out_stride_0, x_flat_stride_0, x_offsets_stride_0, num_rows, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
1029+ pid_0 = tl.program_id(0)
1030+ offset_0 = pid_0 * _BLOCK_SIZE_0
1031+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
1032+ mask_0 = indices_0 < num_rows
1033+ starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0)
1034+ v_0 = tl.full([], 1, tl.int32)
1035+ v_1 = indices_0 + v_0
1036+ ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0)
1037+ v_2 = ends - starts
1038+ _mask_to = tl.where(mask_0, v_2, -9223372036854775808)
1039+ max_seqlen = tl.max(_mask_to, 0)
1040+ for offset_1 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1):
1041+ indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
1042+ mask_1 = indices_1 < M
1043+ max_seqlen_copy = max_seqlen
1044+ starts_copy = starts
1045+ v_2_copy = v_2
1046+ max_seqlen_copy_0 = max_seqlen_copy
1047+ starts_copy_0 = starts_copy
1048+ v_2_copy_0 = v_2_copy
1049+ block_max = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
1050+ block_new_max = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
1051+ block_L = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
1052+ for offset_2 in tl.range(0, max_seqlen_copy_0.to(tl.int32), _BLOCK_SIZE_2):
1053+ indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
1054+ mask_2 = indices_2 < max_seqlen_copy_0
1055+ starts_copy_0_copy = starts_copy_0
1056+ v_2_copy_0_copy = v_2_copy_0
1057+ block_max_copy = block_max
1058+ block_L_copy = block_L
1059+ starts_copy_0_copy_0 = starts_copy_0_copy
1060+ v_2_copy_0_copy_0 = v_2_copy_0_copy
1061+ block_max_copy_0 = block_max_copy
1062+ block_L_copy_0 = block_L_copy
1063+ subscript = starts_copy_0_copy_0[:, None]
1064+ subscript_1 = indices_2[None, :]
1065+ v_3 = subscript_1.to(tl.int64)
1066+ v_4 = subscript + v_3
1067+ subscript_2 = v_4[:, :, None]
1068+ v_5 = subscript_2 * M
1069+ subscript_3 = indices_1[None, None, :]
1070+ v_6 = subscript_3.to(tl.int64)
1071+ v_7 = v_5 + v_6
1072+ subscript_4 = indices_2[None, :]
1073+ subscript_5 = v_2_copy_0_copy_0[:, None]
1074+ v_8 = subscript_4.to(tl.int64)
1075+ v_9 = v_8 < subscript_5
1076+ subscript_6 = v_9[:, :, None]
1077+ v_10 = M.to(tl.int32)
1078+ v_11 = indices_1 < v_10
1079+ subscript_7 = v_11[None, None, :]
1080+ v_12 = subscript_6 & subscript_7
1081+ x_slice = tl.load(x_flat + v_7 * x_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :] & v_12, other=0)
1082+ v_13 = float('-inf')
1083+ v_14 = v_13[None, None, None]
1084+ v_15 = tl.where(v_12, x_slice, v_14)
1085+ _mask_to_1 = tl.where(mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], v_15, float('-inf'))
1086+ slice_max = tl.max(_mask_to_1, 1)
1087+ block_new_max = triton_helpers.maximum(block_max_copy_0, slice_max)
1088+ v_17 = block_max_copy_0 - block_new_max
1089+ v_18 = tl_math.exp(v_17)
1090+ v_19 = block_L_copy_0 * v_18
1091+ subscript_8 = block_new_max[:, None, :]
1092+ v_20 = x_slice - subscript_8
1093+ v_21 = float('-inf')
1094+ v_22 = v_21[None, None, None]
1095+ v_23 = tl.where(v_12, v_20, v_22)
1096+ v_24 = tl_math.exp(v_23)
1097+ _mask_to_2 = tl.where(mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], v_24, 0)
1098+ sum_1 = tl.sum(_mask_to_2, 1)
1099+ block_L = v_19 + sum_1
1100+ block_max = block_new_max
1101+ for offset_3 in tl.range(0, max_seqlen_copy_0.to(tl.int32), _BLOCK_SIZE_3):
1102+ indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
1103+ mask_3 = indices_3 < max_seqlen_copy_0
1104+ starts_copy_0_copy_1 = starts_copy_0
1105+ v_2_copy_0_copy_1 = v_2_copy_0
1106+ block_max_copy_1 = block_max
1107+ block_L_copy_1 = block_L
1108+ starts_copy_0_copy_1_0 = starts_copy_0_copy_1
1109+ v_2_copy_0_copy_1_0 = v_2_copy_0_copy_1
1110+ block_max_copy_1_0 = block_max_copy_1
1111+ block_L_copy_1_0 = block_L_copy_1
1112+ subscript_9 = starts_copy_0_copy_1_0[:, None]
1113+ subscript_10 = indices_3[None, :]
1114+ v_26 = subscript_10.to(tl.int64)
1115+ v_27 = subscript_9 + v_26
1116+ subscript_11 = v_27[:, :, None]
1117+ v_28 = subscript_11 * M
1118+ subscript_12 = indices_1[None, None, :]
1119+ v_29 = subscript_12.to(tl.int64)
1120+ v_30 = v_28 + v_29
1121+ subscript_13 = indices_3[None, :]
1122+ subscript_14 = v_2_copy_0_copy_1_0[:, None]
1123+ v_31 = subscript_13.to(tl.int64)
1124+ v_32 = v_31 < subscript_14
1125+ subscript_15 = v_32[:, :, None]
1126+ v_33 = M.to(tl.int32)
1127+ v_34 = indices_1 < v_33
1128+ subscript_16 = v_34[None, None, :]
1129+ v_35 = subscript_15 & subscript_16
1130+ x_slice_1 = tl.load(x_flat + v_30 * x_flat_stride_0, mask_0[:, None, None] & mask_3[None, :, None] & mask_1[None, None, :] & v_35, other=0)
1131+ subscript_17 = block_max_copy_1_0[:, None, :]
1132+ v_36 = x_slice_1 - subscript_17
1133+ v_37 = tl_math.exp(v_36)
1134+ subscript_18 = block_L_copy_1_0[:, None, :]
1135+ v_38 = v_37 / subscript_18
1136+ tl.store(out + v_30 * out_stride_0, v_38, mask_0[:, None, None] & mask_3[None, :, None] & mask_1[None, None, :] & v_35)
1137+
def jagged_softmax_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launcher=_default_launcher):
    """
    Compute the per-batch softmax in a jagged tensor.

    Args:
        x_data: 2-D tensor of shape (total_elements, max_M) holding all elements
        x_offsets: (num_rows + 1) tensor. Row i is the slice
                   x_data[x_offsets[i] : x_offsets[i+1], :]

    Returns:
        2-D tensor of shape (total_elements, max_M), containing the per-batch softmax scores.
    """
    # Total number of flattened sequence steps is the final offset.
    N = int(x_offsets[-1].item())
    num_rows = x_offsets.size(0) - 1
    M = x_data.size(1)
    # Flat output buffer; the kernel writes (seq * M + feat) positions directly.
    out = torch.zeros(N * M, dtype=x_data.dtype, device=x_data.device)
    x_flat = x_data.view(-1)
    # Tile sizes: rows, features, and the two sequence passes of the kernel.
    _BLOCK_SIZE_0 = 16
    _BLOCK_SIZE_1 = 8
    _BLOCK_SIZE_2 = 16
    _BLOCK_SIZE_3 = 16
    grid = (triton.cdiv(num_rows, _BLOCK_SIZE_0),)
    _launcher(_helion_jagged_softmax_kernel, grid, x_offsets, x_flat, out, out.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
    return out.reshape(N, M)
1160+
10171161--- assertExpectedJournal(TestExamples.test_layernorm)
10181162from __future__ import annotations
10191163
0 commit comments