5 changes: 5 additions & 0 deletions benchmarks/run.py
@@ -187,6 +187,11 @@ class RunResult:
("examples.grouped_gemm", "grouped_gemm_jagged_persistent_tritonbench"),
],
),
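    # Single kernel variant: (tritonbench operator module, example module, entry function).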
"fused_linear_jsd": (
"tritonbench.operators.fused_linear_jsd.operator",
"examples.fused_linear_jsd",
"fused_linear_jsd_fwd_tritonbench",
),
# Multiple kernel variants:
"gemm": (
"tritonbench.operators.gemm.operator",
151 changes: 151 additions & 0 deletions examples/fused_linear_jsd.py
@@ -0,0 +1,151 @@
"""
Fused Linear JSD Example
===========================

This example demonstrates how to implement a Jensen-Shannon divergence (JSD)
kernel in Helion and fuse it with a preceding linear layer.
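
Per row, the loss is the beta-weighted (generalized) Jensen-Shannon divergence::

    JSD_beta(P_s, P_t) = (1 - beta) * KL(P_s || M) + beta * KL(P_t || M)
    M = (1 - beta) * P_s + beta * P_t

where P_s and P_t are the temperature-scaled student and teacher softmax
distributions; the per-row losses are averaged over the batch.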
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import Callable

import torch

import helion
from helion._testing import run_example
import helion.language as hl


# %%
# Helion Kernel
# -------------------
@helion.kernel()
def fused_linear_jsd_kernel(
beta: float,
ignore_index: int,
temperature: float,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
) -> torch.Tensor:
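    """Per-row generalized JSD between student and teacher distributions.

    Logits are temperature-scaled and normalized with log_softmax inside the
    kernel. ``ignore_index`` is accepted to match the reference interface but
    is not used in this forward computation. The returned loss is averaged
    over the batch dimension.
    """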
loss = student_logits.new_empty(student_logits.shape[0], dtype=torch.float)
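    # Tile over the rows (batch dimension); the vocabulary axis is reduced within each tile.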
for batch in hl.tile(student_logits.shape[0]):
student_prob = torch.log_softmax(student_logits[batch, :] / temperature, dim=-1)
teacher_prob = torch.log_softmax(teacher_logits[batch, :] / temperature, dim=-1)
student_prob = student_prob.to(torch.float).view(-1, student_prob.size(-1))
teacher_prob = teacher_prob.to(torch.float).view(-1, teacher_prob.size(-1))
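        # Mixture distribution in probability space: M = (1 - beta) * P_student + beta * P_teacher.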
m = torch.exp(student_prob) + beta * (
torch.exp(teacher_prob) - torch.exp(student_prob)
)
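        # With log_target=True, kl_div(log(M), log_p) sums to KL(p || M) per row.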
teacher_div = torch.nn.functional.kl_div(
torch.log(m), teacher_prob, reduction="none", log_target=True
).sum(dim=-1)
student_div = torch.nn.functional.kl_div(
torch.log(m), student_prob, reduction="none", log_target=True
).sum(dim=-1)
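        # Generalized JSD: (1 - beta) * KL(P_student || M) + beta * KL(P_teacher || M).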
batch_loss = student_div + beta * (teacher_div - student_div)
loss[batch] = batch_loss
return (loss / student_logits.shape[0]).sum()


def fused_linear_jsd_fwd(
beta: float,
ignore_index: int,
temperature: float,
student_weight: torch.Tensor,
teacher_weight: torch.Tensor,
student_input: torch.Tensor,
teacher_input: torch.Tensor,
) -> torch.Tensor:
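    """Project inputs to logits with plain matmuls, then compute the JSD loss
    with the Helion kernel."""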
student_logits = student_input @ student_weight.T
teacher_logits = teacher_input @ teacher_weight.T
return fused_linear_jsd_kernel(
beta, ignore_index, temperature, student_logits, teacher_logits
)


# %%
# Benchmark Entry Point Function
# ------------------------------
def fused_linear_jsd_fwd_tritonbench(
tb_op: object,
student_input: torch.Tensor,
teacher_input: torch.Tensor,
label: torch.Tensor | None = None,
) -> Callable[[], torch.Tensor]:
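    """tritonbench entry point: read beta, ignore_index, temperature, and the
    linear weights from the baseline operator, and return a callable that runs
    the Helion forward pass."""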
assert label is None
baseline_op = tb_op.baseline_op # pyright: ignore[reportAttributeAccessIssue]
beta = baseline_op.jsd.beta
ignore_index = baseline_op.jsd.ignore_index
temperature = baseline_op.temperature
student_weight = baseline_op.student_lin.weight
teacher_weight = baseline_op.teacher_lin.weight
return lambda: fused_linear_jsd_fwd(
beta,
ignore_index,
temperature,
student_weight,
teacher_weight,
student_input,
teacher_input,
)


# %%
# Reference Implementation
# ------------------------
def fused_linear_jsd_pytorch(
beta: float,
ignore_index: int,
temperature: float,
student_weight: torch.Tensor,
teacher_weight: torch.Tensor,
student_input: torch.Tensor,
teacher_input: torch.Tensor,
) -> torch.Tensor:
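    """Eager PyTorch reference implementation used for verification."""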
student_logits = student_input @ student_weight.T
teacher_logits = teacher_input @ teacher_weight.T
student_prob = torch.log_softmax(student_logits / temperature, dim=-1)
teacher_prob = torch.log_softmax(teacher_logits / temperature, dim=-1)
student_prob = student_prob.to(torch.float).view(-1, student_prob.size(-1))
teacher_prob = teacher_prob.to(torch.float).view(-1, teacher_prob.size(-1))
m = torch.exp(student_prob) + beta * (
torch.exp(teacher_prob) - torch.exp(student_prob)
)
teacher_div = torch.nn.functional.kl_div(
torch.log(m), teacher_prob, reduction="none", log_target=True
).sum(dim=-1)
student_div = torch.nn.functional.kl_div(
torch.log(m), student_prob, reduction="none", log_target=True
).sum(dim=-1)
loss = student_div + beta * (teacher_div - student_div)
return (loss / student_logits.shape[0]).sum()


# %%
# Verification Function
# ---------------------
def check(m: int, n: int, k: int) -> None:
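    """Compare the Helion kernel against the PyTorch reference on random data.

    Here m is the batch size, n the hidden size, and k the projected
    (vocabulary-like) dimension.
    """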
student_input = torch.rand([m, n], device="cuda", dtype=torch.float)
teacher_input = torch.rand([m, n], device="cuda", dtype=torch.float)
student_weight = torch.rand([k, n], device="cuda", dtype=torch.float)
teacher_weight = torch.rand([k, n], device="cuda", dtype=torch.float)
run_example(
fused_linear_jsd_fwd,
fused_linear_jsd_pytorch,
(0.5, 1, 1.0, student_weight, teacher_weight, student_input, teacher_input),
)


# %%
# Main Function
# -------------
def main() -> None:
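    # Problem size on the order of an LLM head; 128256 corresponds to the Llama 3 vocabulary size.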
check(1024, 4096, 128256)


if __name__ == "__main__":
main()
70 changes: 70 additions & 0 deletions test/test_examples.expected
@@ -841,6 +841,76 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
_launcher(_helion_fp8_gemm, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestExamples.test_fused_linear_jsd)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import math as tl_math
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_fused_linear_jsd_kernel(student_logits, teacher_logits, loss, student_logits_size_0, teacher_logits_size_1, loss_stride_0, student_logits_stride_0, student_logits_stride_1, teacher_logits_stride_0, teacher_logits_stride_1, temperature, beta, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < student_logits_size_0
indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
mask_1 = indices_1 < teacher_logits_size_1
load = tl.load(student_logits + (indices_0[:, None] * student_logits_stride_0 + indices_1[None, :] * student_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
v_0 = load / temperature
_mask_to = tl.where(mask_0[:, None] & mask_1[None, :], v_0, tl.full([], float('-inf'), tl.float32))
amax = tl.cast(tl.reshape(tl.max(_mask_to, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
v_1 = v_0 - amax
v_2 = libdevice.exp(v_1)
_mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], v_2, tl.full([], 0, tl.float32))
sum_1 = tl.cast(tl.reshape(tl.sum(_mask_to_1, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
v_3 = tl_math.log(sum_1)
v_4 = v_1 - v_3
load_1 = tl.load(teacher_logits + (indices_0[:, None] * teacher_logits_stride_0 + indices_1[None, :] * teacher_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
v_5 = load_1 / temperature
_mask_to_2 = tl.where(mask_0[:, None] & mask_1[None, :], v_5, tl.full([], float('-inf'), tl.float32))
amax_1 = tl.cast(tl.reshape(tl.max(_mask_to_2, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
v_6 = v_5 - amax_1
v_7 = libdevice.exp(v_6)
_mask_to_3 = tl.where(mask_0[:, None] & mask_1[None, :], v_7, tl.full([], 0, tl.float32))
sum_2 = tl.cast(tl.reshape(tl.sum(_mask_to_3, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
v_8 = tl_math.log(sum_2)
v_9 = v_6 - v_8
student_prob_1 = tl.reshape(v_4, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
teacher_prob_1 = tl.reshape(v_9, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
v_10 = libdevice.exp(student_prob_1)
v_11 = libdevice.exp(teacher_prob_1)
v_12 = libdevice.exp(student_prob_1)
v_13 = v_11 - v_12
v_14 = v_13 * beta
v_15 = v_10 + v_14
v_16 = tl_math.log(v_15)
v_17 = teacher_prob_1 - v_16
v_18 = libdevice.exp(teacher_prob_1)
v_19 = v_18 * v_17
_mask_to_4 = tl.where(mask_0[:, None] & mask_1[None, :], v_19, tl.full([], 0, tl.float32))
teacher_div = tl.cast(tl.sum(_mask_to_4, 1), tl.float32)
v_20 = tl_math.log(v_15)
v_21 = student_prob_1 - v_20
v_22 = libdevice.exp(student_prob_1)
v_23 = v_22 * v_21
_mask_to_5 = tl.where(mask_0[:, None] & mask_1[None, :], v_23, tl.full([], 0, tl.float32))
student_div = tl.cast(tl.sum(_mask_to_5, 1), tl.float32)
v_24 = teacher_div - student_div
v_25 = v_24 * beta
v_26 = student_div + v_25
tl.store(loss + indices_0 * loss_stride_0, v_26, mask_0)

def fused_linear_jsd_kernel(beta: float, ignore_index: int, temperature: float, student_logits: torch.Tensor, teacher_logits: torch.Tensor, *, _launcher=_default_launcher):
loss = student_logits.new_empty(student_logits.shape[0], dtype=torch.float)
_BLOCK_SIZE_0 = 32
_RDIM_SIZE_1 = triton.next_power_of_2(teacher_logits.size(1))
_launcher(_helion_fused_linear_jsd_kernel, (triton.cdiv(student_logits.size(0), _BLOCK_SIZE_0),), student_logits, teacher_logits, loss, student_logits.size(0), teacher_logits.size(1), loss.stride(0), student_logits.stride(0), student_logits.stride(1), teacher_logits.stride(0), teacher_logits.stride(1), temperature, beta, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
return (loss / student_logits.shape[0]).sum()

--- assertExpectedJournal(TestExamples.test_gather_gemv)
from __future__ import annotations

37 changes: 37 additions & 0 deletions test/test_examples.py
@@ -1239,6 +1239,43 @@ def test_int4_gemm(self):
)
)

def test_fused_linear_jsd(self):
beta = 0.5
ignore_index = 1
temperature = 1.0
m, n, k = 64, 128, 256

student_input = torch.randn([m, n], device=DEVICE, dtype=torch.float32)
teacher_input = torch.randn([m, n], device=DEVICE, dtype=torch.float32)
student_weight = torch.randn([k, n], device=DEVICE, dtype=torch.float32)
teacher_weight = torch.randn([k, n], device=DEVICE, dtype=torch.float32)
student_logits = student_input @ student_weight.T
teacher_logits = teacher_input @ teacher_weight.T
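        # Logits are precomputed so the journal check exercises only the JSD kernel.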

args = (
beta,
ignore_index,
temperature,
student_logits,
teacher_logits,
)

# Import and use the reference implementation
mod = import_path(EXAMPLES_DIR / "fused_linear_jsd.py")
expected = mod.fused_linear_jsd_pytorch(
            *args[:-2], student_weight, teacher_weight, student_input, teacher_input
)

self.assertExpectedJournal(
check_example(
"fused_linear_jsd",
args,
expected,
fn_name="fused_linear_jsd_kernel",
block_sizes=[32],
)
)


if __name__ == "__main__":
unittest.main()