@@ -840,6 +840,77 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
840840 _launcher(_helion_fp8_gemm, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
841841 return out
842842
843+ --- assertExpectedJournal(TestExamples.test_fused_linear_jsd)
844+ from __future__ import annotations
845+
846+ import torch
847+ import triton
848+ import triton.language as tl
849+ from torch._inductor.runtime.triton_helpers import math as tl_math
850+ from helion.runtime import default_launcher as _default_launcher
851+
@triton.jit
def _helion_fused_linear_jsd_fwd(student_logits, teacher_logits, loss, student_input_size_0, student_weight_size_0, loss_stride_0, student_logits_stride_0, student_logits_stride_1, teacher_logits_stride_0, teacher_logits_stride_1, temperature, beta, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
    # Generated Helion kernel (expected-journal snapshot): computes a per-row,
    # beta-weighted Jensen-Shannon-style divergence between the
    # temperature-scaled log-softmax of student and teacher logits.
    # Each program processes a block of _BLOCK_SIZE_0 rows; the full class
    # dimension is reduced in one _RDIM_SIZE_1-wide tile (next power of two of
    # the vocabulary size, masked with mask_1 past the real extent).
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    # Row indices handled by this program, masked against the true row count.
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < student_input_size_0
    # Column (class) indices, masked against the true vocabulary size.
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < student_weight_size_0
    # --- student: temperature-scaled, numerically stable log-softmax ---
    load = tl.load(student_logits + (indices_0[:, None] * student_logits_stride_0 + indices_1[None, :] * student_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
    v_0 = load / temperature
    # Masked lanes become -inf so they do not affect the row max.
    _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], v_0, float('-inf'))
    amax = tl.reshape(tl.max(_mask_to, 1), [_BLOCK_SIZE_0, 1])
    v_1 = v_0 - amax
    v_2 = tl_math.exp(v_1)
    # Masked lanes contribute 0 to the log-sum-exp denominator.
    _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], v_2, 0)
    sum_1 = tl.reshape(tl.sum(_mask_to_1, 1), [_BLOCK_SIZE_0, 1])
    v_3 = tl_math.log(sum_1)
    # v_4 = log_softmax(student_logits / temperature) per row.
    v_4 = v_1 - v_3
    # --- teacher: same stable log-softmax, applied to teacher logits ---
    load_1 = tl.load(teacher_logits + (indices_0[:, None] * teacher_logits_stride_0 + indices_1[None, :] * teacher_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
    v_5 = load_1 / temperature
    _mask_to_2 = tl.where(mask_0[:, None] & mask_1[None, :], v_5, float('-inf'))
    amax_1 = tl.reshape(tl.max(_mask_to_2, 1), [_BLOCK_SIZE_0, 1])
    v_6 = v_5 - amax_1
    v_7 = tl_math.exp(v_6)
    _mask_to_3 = tl.where(mask_0[:, None] & mask_1[None, :], v_7, 0)
    sum_2 = tl.reshape(tl.sum(_mask_to_3, 1), [_BLOCK_SIZE_0, 1])
    v_8 = tl_math.log(sum_2)
    # v_9 = log_softmax(teacher_logits / temperature) per row.
    v_9 = v_6 - v_8
    student_prob_1 = tl.reshape(v_4, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
    teacher_prob_1 = tl.reshape(v_9, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
    # Mixture in probability space: v_15 = exp(s) + beta*(exp(t) - exp(s)),
    # i.e. (1-beta)*softmax(student) + beta*softmax(teacher).
    # NOTE(review): exp(student_prob_1) / exp(teacher_prob_1) are recomputed
    # (v_10/v_12/v_22, v_11/v_18) and log(v_15) twice (v_16/v_20) — redundancy
    # left as-is; this is machine-generated expected output.
    v_10 = tl_math.exp(student_prob_1)
    v_11 = tl_math.exp(teacher_prob_1)
    v_12 = tl_math.exp(student_prob_1)
    v_13 = v_11 - v_12
    v_14 = v_13 * beta
    v_15 = v_10 + v_14
    v_16 = tl_math.log(v_15)
    # teacher_div = sum_c softmax(t)*(log_softmax(t) - log(mixture))
    # (a KL-divergence-style term of teacher against the mixture).
    v_17 = teacher_prob_1 - v_16
    v_18 = tl_math.exp(teacher_prob_1)
    v_19 = v_18 * v_17
    _mask_to_4 = tl.where(mask_0[:, None] & mask_1[None, :], v_19, 0)
    teacher_div = tl.sum(_mask_to_4, 1)
    v_20 = tl_math.log(v_15)
    # student_div = the symmetric term for the student distribution.
    v_21 = student_prob_1 - v_20
    v_22 = tl_math.exp(student_prob_1)
    v_23 = v_22 * v_21
    _mask_to_5 = tl.where(mask_0[:, None] & mask_1[None, :], v_23, 0)
    student_div = tl.sum(_mask_to_5, 1)
    # Final blend: (1-beta)*student_div + beta*teacher_div, written per row.
    v_24 = teacher_div - student_div
    v_25 = v_24 * beta
    v_26 = student_div + v_25
    tl.store(loss + indices_0 * loss_stride_0, v_26, mask_0)
904+
def fused_linear_jsd_fwd(beta: float, ignore_index: int, temperature: float, student_weight: torch.Tensor, teacher_weight: torch.Tensor, student_input: torch.Tensor, teacher_input: torch.Tensor, *, _launcher=_default_launcher):
    """Host wrapper: project inputs through their linear heads, then launch the
    fused JSD kernel and reduce the per-row losses to a scalar.

    The linear layers run as plain matmuls; only the JSD reduction is fused.
    ``ignore_index`` is accepted for interface compatibility but not used here.
    """
    # Classifier projections for both models.
    s_logits = student_input @ student_weight.T
    t_logits = teacher_input @ teacher_weight.T
    # One float32 loss value per input row.
    loss = s_logits.new_empty(student_input.shape[0], dtype=torch.float)
    _BLOCK_SIZE_0 = 32
    # Class dimension is reduced in a single power-of-two tile.
    _RDIM_SIZE_1 = triton.next_power_of_2(student_weight.size(0))
    grid = (triton.cdiv(student_input.size(0), _BLOCK_SIZE_0),)
    _launcher(
        _helion_fused_linear_jsd_fwd,
        grid,
        s_logits,
        t_logits,
        loss,
        student_input.size(0),
        student_weight.size(0),
        loss.stride(0),
        s_logits.stride(0),
        s_logits.stride(1),
        t_logits.stride(0),
        t_logits.stride(1),
        temperature,
        beta,
        _BLOCK_SIZE_0,
        _RDIM_SIZE_1,
        num_warps=4,
        num_stages=3,
    )
    # Normalize each row's loss by the batch size, then sum — kept as
    # (loss / N).sum() rather than loss.mean() to preserve exact float order.
    return (loss / s_logits.shape[0]).sum()
913+
843914--- assertExpectedJournal(TestExamples.test_jagged_dense_add)
844915from __future__ import annotations
845916
0 commit comments