@@ -841,6 +841,77 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_fp8_gemm, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_fused_linear_jsd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fused_linear_jsd_fwd(student_logits, teacher_logits, loss, student_input_size_0, student_weight_size_0, loss_stride_0, student_logits_stride_0, student_logits_stride_1, teacher_logits_stride_0, teacher_logits_stride_1, temperature, beta, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < student_input_size_0
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < student_weight_size_0
+    load = tl.load(student_logits + (indices_0[:, None] * student_logits_stride_0 + indices_1[None, :] * student_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = load / temperature
+    _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], v_0, tl.full([], float('-inf'), tl.float32))
+    amax = tl.cast(tl.reshape(tl.max(_mask_to, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_1 = v_0 - amax
+    v_2 = tl_math.exp(v_1)
+    _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], v_2, tl.full([], 0, tl.float32))
+    sum_1 = tl.cast(tl.reshape(tl.sum(_mask_to_1, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_3 = tl_math.log(sum_1)
+    v_4 = v_1 - v_3
+    load_1 = tl.load(teacher_logits + (indices_0[:, None] * teacher_logits_stride_0 + indices_1[None, :] * teacher_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_5 = load_1 / temperature
+    _mask_to_2 = tl.where(mask_0[:, None] & mask_1[None, :], v_5, tl.full([], float('-inf'), tl.float32))
+    amax_1 = tl.cast(tl.reshape(tl.max(_mask_to_2, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_6 = v_5 - amax_1
+    v_7 = tl_math.exp(v_6)
+    _mask_to_3 = tl.where(mask_0[:, None] & mask_1[None, :], v_7, tl.full([], 0, tl.float32))
+    sum_2 = tl.cast(tl.reshape(tl.sum(_mask_to_3, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_8 = tl_math.log(sum_2)
+    v_9 = v_6 - v_8
+    student_prob_1 = tl.reshape(v_4, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
+    teacher_prob_1 = tl.reshape(v_9, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
+    v_10 = tl_math.exp(student_prob_1)
+    v_11 = tl_math.exp(teacher_prob_1)
+    v_12 = tl_math.exp(student_prob_1)
+    v_13 = v_11 - v_12
+    v_14 = v_13 * beta
+    v_15 = v_10 + v_14
+    v_16 = tl_math.log(v_15)
+    v_17 = teacher_prob_1 - v_16
+    v_18 = tl_math.exp(teacher_prob_1)
+    v_19 = v_18 * v_17
+    _mask_to_4 = tl.where(mask_0[:, None] & mask_1[None, :], v_19, tl.full([], 0, tl.float32))
+    teacher_div = tl.cast(tl.sum(_mask_to_4, 1), tl.float32)
+    v_20 = tl_math.log(v_15)
+    v_21 = student_prob_1 - v_20
+    v_22 = tl_math.exp(student_prob_1)
+    v_23 = v_22 * v_21
+    _mask_to_5 = tl.where(mask_0[:, None] & mask_1[None, :], v_23, tl.full([], 0, tl.float32))
+    student_div = tl.cast(tl.sum(_mask_to_5, 1), tl.float32)
+    v_24 = teacher_div - student_div
+    v_25 = v_24 * beta
+    v_26 = student_div + v_25
+    tl.store(loss + indices_0 * loss_stride_0, v_26, mask_0)
+
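+# Summary of what the generated kernel above computes, in terms of its temporaries:
+# v_4 and v_9 are the temperature-scaled log-softmax of the student and teacher logits
+# (the amax subtraction keeps exp() numerically stable), v_15 is the mixture
+# m = (1 - beta) * p_student + beta * p_teacher, and the per-row value stored into `loss`
+# is the generalized Jensen-Shannon divergence
+# (1 - beta) * KL(p_student || m) + beta * KL(p_teacher || m).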
+def fused_linear_jsd_fwd(beta: float, ignore_index: int, temperature: float, student_weight: torch.Tensor, teacher_weight: torch.Tensor, student_input: torch.Tensor, teacher_input: torch.Tensor, *, _launcher=_default_launcher):
+    student_logits = student_input @ student_weight.T
+    teacher_logits = teacher_input @ teacher_weight.T
+    loss = student_logits.new_empty(student_input.shape[0], dtype=torch.float)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = triton.next_power_of_2(student_weight.size(0))
+    _launcher(_helion_fused_linear_jsd_fwd, (triton.cdiv(student_input.size(0), _BLOCK_SIZE_0),), student_logits, teacher_logits, loss, student_input.size(0), student_weight.size(0), loss.stride(0), student_logits.stride(0), student_logits.stride(1), teacher_logits.stride(0), teacher_logits.stride(1), temperature, beta, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return (loss / student_logits.shape[0]).sum()
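+# Illustrative call, assuming hypothetical shapes and a CUDA device (tokens=64,
+# hidden=512, vocab=1000): inputs are (tokens, hidden) and weights are (vocab, hidden),
+# so both logits become (tokens, vocab); ignore_index is accepted but unused in this
+# forward path.
+#
+#   student_input = torch.randn(64, 512, device="cuda")
+#   teacher_input = torch.randn(64, 512, device="cuda")
+#   student_weight = torch.randn(1000, 512, device="cuda")
+#   teacher_weight = torch.randn(1000, 512, device="cuda")
+#   loss = fused_linear_jsd_fwd(0.5, -100, 1.0, student_weight, teacher_weight,
+#                               student_input, teacher_input)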
+
 --- assertExpectedJournal(TestExamples.test_jagged_dense_add)
 from __future__ import annotations
 