@@ -469,6 +469,50 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch
469469 _launcher(_concat2d_dim1_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), x, out, y, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
470470 return out
471471
472+ --- assertExpectedJournal(TestExamples.test_cross_entropy)
473+ from __future__ import annotations
474+
475+ import torch
476+ import triton
477+ import triton.language as tl
478+ from torch._inductor.runtime.triton_helpers import math as tl_math
479+ from helion.runtime import default_launcher as _default_launcher
480+
481+ @triton.jit
482+ def _cross_entropy_kernel(labels, logits_flat, logits, losses, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
483+ pid_0 = tl.program_id(0)
484+ offset_0 = pid_0
485+ indices_0 = offset_0 + tl.zeros([1], tl.int32)
486+ indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
487+ mask_1 = indices_1 < v
488+ labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
489+ v_0 = v.to(tl.int32)
490+ v_1 = indices_0 * v_0
491+ v_2 = v_1.to(tl.int64)
492+ v_3 = v_2 + labels_tile
493+ logits_at_target = tl.load(logits_flat + v_3 * logits_flat_stride_0, None)
494+ logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
495+ _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
496+ max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
497+ v_4 = logits_rows - max_logits
498+ v_5 = tl_math.exp(v_4)
499+ _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_5, 0)
500+ sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
501+ squeeze = tl.reshape(max_logits, [1])
502+ squeeze_1 = tl.reshape(sum_exp, [1])
503+ v_6 = tl_math.log(squeeze_1)
504+ v_7 = squeeze + v_6
505+ v_8 = v_7 - logits_at_target
506+ tl.store(losses + indices_0 * losses_stride_0, v_8, None)
507+
508+ def cross_entropy(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_default_launcher):
509+ n, v = logits.shape
510+ losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
511+ logits_flat = logits.view(-1)
512+ _RDIM_SIZE_1 = triton.next_power_of_2(v)
513+ _launcher(_cross_entropy_kernel, (n,), labels, logits_flat, logits, losses, labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
514+ return losses.mean()
515+
472516--- assertExpectedJournal(TestExamples.test_embedding_block_ptr)
473517from __future__ import annotations
474518
@@ -530,6 +574,94 @@ def embedding(x: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launc
530574 _launcher(_embedding_kernel, (x_flat.size(0) * triton.cdiv(embedding_dim, _BLOCK_SIZE_1),), x_flat, weight, out, x_flat.size(0), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
531575 return out.view(*x.size(), embedding_dim)
532576
577+ --- assertExpectedJournal(TestExamples.test_fp8_attention)
578+ from __future__ import annotations
579+
580+ import math
581+ import torch
582+ import triton
583+ import triton.language as tl
584+ from torch._inductor.runtime import triton_helpers
585+ from torch._inductor.runtime.triton_compat import libdevice
586+
587+ @triton.jit
588+ def _fp8_attention_kernel_kernel(q, k, v, out, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
589+ pid_0 = tl.program_id(0)
590+ offset_0 = pid_0
591+ indices_5 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
592+ for offset_4 in tl.range(0, 256, _BLOCK_SIZE_1):
593+ indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
594+ m_i = tl.full([_BLOCK_SIZE_1], float('-inf'), tl.float32)
595+ l_i = tl.full([_BLOCK_SIZE_1], 0.0, tl.float32)
596+ acc = tl.full([_BLOCK_SIZE_1, 64], 0.0, tl.float32)
597+ q_tile = tl.load(q + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), None)
598+ for offset_2 in tl.range(0, 256, _BLOCK_SIZE_3):
599+ indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
600+ q_tile_copy = q_tile
601+ m_i_copy = m_i
602+ l_i_copy = l_i
603+ acc_copy = acc
604+ q_tile_copy_0 = q_tile_copy
605+ m_i_copy_0 = m_i_copy
606+ l_i_copy_0 = l_i_copy
607+ acc_copy_0 = acc_copy
608+ k_tile = tl.load(k + (offset_0 * 16384 + indices_2[:, None] * 64 + indices_5[None, :] * 1), None)
609+ k_tile_t = tl.permute(k_tile, [1, 0])
610+ mm = tl.dot(q_tile_copy_0, k_tile_t)
611+ v_0 = mm.to(tl.float32)
612+ v_1 = 0.18033688
613+ v_2 = v_0 * v_1
614+ qk_max = tl.max(v_2, 1)
615+ v_3 = triton_helpers.maximum(m_i_copy_0, qk_max)
616+ subscript = v_3[:, None]
617+ v_4 = v_2 - subscript
618+ v_5 = libdevice.exp2(v_4)
619+ l_ij = tl.sum(v_5, 1)
620+ v_6 = m_i_copy_0 - v_3
621+ v_7 = libdevice.exp2(v_6)
622+ v_8 = l_i_copy_0 * v_7
623+ l_i = v_8 + l_ij
624+ subscript_1 = v_7[:, None]
625+ v_10 = acc_copy_0 * subscript_1
626+ v_tile = tl.load(v + (offset_0 * 16384 + indices_5[:, None] * 1 + indices_2[None, :] * 64), None)
627+ v_11 = v_5.to(tl.float8e5)
628+ v_t = tl.permute(v_tile, [1, 0])
629+ mm_1 = tl.dot(v_11, v_t)
630+ v_12 = mm_1.to(tl.float32)
631+ acc = v_10 + v_12
632+ m_i = v_3
633+ subscript_2 = l_i[:, None]
634+ v_14 = acc / subscript_2
635+ tl.store(out + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_14, None)
636+
637+ def fp8_attention_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
638+ """FP8 attention kernel processing batch*heads in parallel."""
639+ batch_heads = q.size(0)
640+ seq_len = q.size(1)
641+ head_dim = q.size(2)
642+ out = torch.empty([batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device)
643+ sm_scale = 1.0 / math.sqrt(float(head_dim))
644+ sm_scale = sm_scale * 1.44269504
645+ _RDIM_SIZE_2 = 64
646+ _BLOCK_SIZE_1 = 64
647+ _BLOCK_SIZE_3 = 64
648+ _fp8_attention_kernel_kernel[8,](q, k, v, out, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
649+ return out
650+
651+ def _fp8_attention_kernel_make_precompiler(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
652+ """FP8 attention kernel processing batch*heads in parallel."""
653+ batch_heads = q.size(0)
654+ seq_len = q.size(1)
655+ head_dim = q.size(2)
656+ out = torch.empty([batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device)
657+ sm_scale = 1.0 / math.sqrt(float(head_dim))
658+ sm_scale = sm_scale * 1.44269504
659+ _RDIM_SIZE_2 = 64
660+ _BLOCK_SIZE_1 = 64
661+ _BLOCK_SIZE_3 = 64
662+ from helion.runtime.precompile_shim import make_precompiler
663+ return make_precompiler(_fp8_attention_kernel_kernel)(q, k, v, out, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
664+
533665--- assertExpectedJournal(TestExamples.test_fp8_gemm)
534666from __future__ import annotations
535667
@@ -762,6 +894,139 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_
762894 _launcher(_jagged_mean_kernel_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
763895 return out
764896
897+ --- assertExpectedJournal(TestExamples.test_jagged_mean_2d)
898+ from __future__ import annotations
899+
900+ import torch
901+ import triton
902+ import triton.language as tl
903+
904+ @triton.jit
905+ def _jagged_mean_kernel_2d_kernel(x_offsets, x_feature_counts, x_flat, out, out_stride_0, out_stride_1, x_feature_counts_stride_0, x_flat_stride_0, x_offsets_stride_0, num_rows, max_M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
906+ pid_0 = tl.program_id(0)
907+ offset_0 = pid_0 * _BLOCK_SIZE_0
908+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
909+ mask_0 = indices_0 < num_rows
910+ starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0)
911+ v_0 = tl.full([], 1, tl.int32)
912+ v_1 = indices_0 + v_0
913+ ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0)
914+ v_2 = ends - starts
915+ _mask_to = tl.where(mask_0, v_2, -9223372036854775808)
916+ max_nnz = tl.max(_mask_to, 0)
917+ feature_counts = tl.load(x_feature_counts + indices_0 * x_feature_counts_stride_0, mask_0, other=0)
918+ for offset_1 in tl.range(0, max_M.to(tl.int32), step=_BLOCK_SIZE_1):
919+ indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
920+ mask_1 = indices_1 < max_M
921+ feature_counts_copy = feature_counts
922+ max_nnz_copy = max_nnz
923+ starts_copy = starts
924+ v_2_copy = v_2
925+ feature_counts_copy_0 = feature_counts_copy
926+ max_nnz_copy_0 = max_nnz_copy
927+ starts_copy_0 = starts_copy
928+ v_2_copy_0 = v_2_copy
929+ subscript = feature_counts_copy_0[:, None]
930+ v_3 = indices_1[None, :]
931+ v_4 = v_3 < subscript
932+ row_sums = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
933+ for offset_2 in tl.range(0, max_nnz_copy_0.to(tl.int32), step=_BLOCK_SIZE_2):
934+ indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
935+ mask_2 = indices_2 < max_nnz_copy_0
936+ starts_copy_0_copy = starts_copy_0
937+ v_2_copy_0_copy = v_2_copy_0
938+ v_4_copy = v_4
939+ row_sums_copy = row_sums
940+ starts_copy_0_copy_0 = starts_copy_0_copy
941+ v_2_copy_0_copy_0 = v_2_copy_0_copy
942+ v_4_copy_0 = v_4_copy
943+ row_sums_copy_0 = row_sums_copy
944+ subscript_1 = starts_copy_0_copy_0[:, None]
945+ subscript_2 = indices_2[None, :]
946+ v_5 = subscript_2.to(tl.int64)
947+ v_6 = subscript_1 + v_5
948+ subscript_3 = v_6[:, :, None]
949+ v_7 = subscript_3 * max_M
950+ subscript_4 = indices_1[None, None, :]
951+ v_8 = subscript_4.to(tl.int64)
952+ v_9 = v_7 + v_8
953+ subscript_5 = indices_2[None, :]
954+ subscript_6 = v_2_copy_0_copy_0[:, None]
955+ v_10 = subscript_5.to(tl.int64)
956+ v_11 = v_10 < subscript_6
957+ subscript_7 = v_11[:, :, None]
958+ subscript_8 = v_4_copy_0[:, None, :]
959+ v_12 = subscript_7 & subscript_8
960+ x_slice = tl.load(x_flat + v_9 * x_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :] & v_12, other=0)
961+ sum_1 = tl.sum(x_slice, 1)
962+ row_sums = row_sums_copy_0 + sum_1
963+ v_14 = v_2_copy_0.to(tl.float32)
964+ nnz_expanded = v_14[:, None]
965+ v_15 = 0.0
966+ v_16 = nnz_expanded > v_15
967+ v_17 = row_sums / nnz_expanded
968+ v_18 = 0.0
969+ v_19 = v_18[None, None]
970+ v_20 = tl.where(v_16, v_17, v_19)
971+ v_21 = 0.0
972+ v_22 = v_21[None, None]
973+ v_23 = tl.where(v_4, v_20, v_22)
974+ tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_23, mask_0[:, None] & mask_1[None, :])
975+
976+ def jagged_mean_kernel_2d(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_counts: torch.Tensor, max_M_tensor: torch.Tensor):
977+ """
978+ Compute the mean of each row in a 2D jagged tensor with variable features per row.
979+
980+ Args
981+ ----
982+ x_data : 2-D tensor of shape (total_elements, max_M) holding all elements.
983+ x_offsets : (num_rows + 1) tensor. Row i is the slice
984+ x_data[x_offsets[i] : x_offsets[i+1], :].
985+ x_feature_counts: (num_rows) tensor. Number of valid features for each row.
986+ max_M_tensor : Dummy tensor whose numel() gives max number of features.
987+
988+ Returns
989+ -------
990+ result : 2-D tensor of shape (num_rows, max_M) containing the mean of each row.
991+ Invalid features (beyond x_feature_counts[i]) are set to 0.
992+ """
993+ num_rows = x_offsets.size(0) - 1
994+ max_M = max_M_tensor.numel()
995+ out = torch.zeros([num_rows, max_M], dtype=x_data.dtype, device=x_data.device)
996+ x_flat = x_data.view(-1)
997+ _BLOCK_SIZE_0 = 16
998+ _BLOCK_SIZE_1 = 8
999+ _BLOCK_SIZE_2 = 16
1000+ _jagged_mean_kernel_2d_kernel[triton.cdiv(num_rows, _BLOCK_SIZE_0),](x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
1001+ return out
1002+
1003+ def _jagged_mean_kernel_2d_make_precompiler(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_counts: torch.Tensor, max_M_tensor: torch.Tensor):
1004+ """
1005+ Compute the mean of each row in a 2D jagged tensor with variable features per row.
1006+
1007+ Args
1008+ ----
1009+ x_data : 2-D tensor of shape (total_elements, max_M) holding all elements.
1010+ x_offsets : (num_rows + 1) tensor. Row i is the slice
1011+ x_data[x_offsets[i] : x_offsets[i+1], :].
1012+ x_feature_counts: (num_rows) tensor. Number of valid features for each row.
1013+ max_M_tensor : Dummy tensor whose numel() gives max number of features.
1014+
1015+ Returns
1016+ -------
1017+ result : 2-D tensor of shape (num_rows, max_M) containing the mean of each row.
1018+ Invalid features (beyond x_feature_counts[i]) are set to 0.
1019+ """
1020+ num_rows = x_offsets.size(0) - 1
1021+ max_M = max_M_tensor.numel()
1022+ out = torch.zeros([num_rows, max_M], dtype=x_data.dtype, device=x_data.device)
1023+ x_flat = x_data.view(-1)
1024+ _BLOCK_SIZE_0 = 16
1025+ _BLOCK_SIZE_1 = 8
1026+ _BLOCK_SIZE_2 = 16
1027+ from helion.runtime.precompile_shim import make_precompiler
1028+ return make_precompiler(_jagged_mean_kernel_2d_kernel)(x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
1029+
7651030--- assertExpectedJournal(TestExamples.test_matmul)
7661031from __future__ import annotations
7671032
0 commit comments