
Commit 4d576ec

[example] fused_linear_jsd

1 parent bbc2be4 commit 4d576ec

File tree

5 files changed: +250 -1 lines changed

benchmarks/run.py (+7)
examples/fused_linear_jsd.py (+135)
helion/autotuner/base_search.py (+2, -1)
test/test_examples.expected (+71)
test/test_examples.py (+35)

benchmarks/run.py

Lines changed: 7 additions & 0 deletions

@@ -141,6 +141,11 @@ class RunResult:
             ("examples.matmul_split_k", "matmul_split_k_tritonbench"),
         ],
     ),
+    "fused_linear_jsd": (
+        "tritonbench.operators.fused_linear_jsd.operator",
+        "examples.fused_linear_jsd",
+        "fused_linear_jsd_fwd_tritonbench",
+    ),
 }


@@ -500,6 +505,8 @@ def helion_method(
         attr.settings.force_autotune = True
         attr.settings.static_shape = True  # pyright: ignore[reportAttributeAccessIssue]

+    kfunc._self = self  # pyright: ignore[reportFunctionMemberAccess]
+
     def _inner() -> Callable[..., Any] | object:
         # BENCHMARK HOT PATH, do not add any new logic here
         result = kfunc(*args, **kwargs)
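The `kfunc._self = self` line added to helion_method lets the example's tritonbench entry point reach back to the running operator: run.py stashes the operator wrapper on the kernel function object, and fused_linear_jsd_fwd_tritonbench (in the new example file below) reads beta, temperature, and the linear weights off `._self.baseline_op`. A minimal sketch of that function-attribute handoff; Operator and entry_point are illustrative names, not the tritonbench API, and only the `_self` convention mirrors the diff:

# Minimal sketch of the function-attribute handoff used above.
# "Operator" and its attributes are assumptions for illustration only.
class Operator:
    temperature = 1.0  # hyperparameter the wrapper wants to reuse

def entry_point(x):
    op = entry_point._self          # recover the operator stashed by the runner
    return x / op.temperature

runner_op = Operator()
entry_point._self = runner_op       # what `kfunc._self = self` does in run.py
print(entry_point(2.0))             # -> 2.0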

examples/fused_linear_jsd.py

Lines changed: 135 additions & 0 deletions

@@ -0,0 +1,135 @@
"""
Fused Linear JSD Example
===========================

This example demonstrates how to implement a JSD kernel using Helion and
fuse it with a linear layer.
"""

# %%
# Imports
# -------
from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl

# %%
# Helion Kernel
# -------------------
@helion.kernel()
def fused_linear_jsd_fwd(
    beta: float,
    ignore_index: int,
    temperature: float,
    student_weight: torch.Tensor,
    teacher_weight: torch.Tensor,
    student_input: torch.Tensor,
    teacher_input: torch.Tensor,
) -> torch.Tensor:
    student_logits = student_input @ student_weight.T
    teacher_logits = teacher_input @ teacher_weight.T
    loss = student_logits.new_empty(student_input.shape[0], dtype=torch.float)
    for batch in hl.tile(student_logits.shape[0]):
        student_prob = torch.log_softmax(student_logits[batch, :] / temperature, dim=-1)
        teacher_prob = torch.log_softmax(teacher_logits[batch, :] / temperature, dim=-1)
        student_prob = student_prob.to(torch.float).view(-1, student_prob.size(-1))
        teacher_prob = teacher_prob.to(torch.float).view(-1, teacher_prob.size(-1))
        m = torch.exp(student_prob) + beta * (
            torch.exp(teacher_prob) - torch.exp(student_prob)
        )
        teacher_div = torch.nn.functional.kl_div(
            torch.log(m), teacher_prob, reduction="none", log_target=True
        ).sum(dim=-1)
        student_div = torch.nn.functional.kl_div(
            torch.log(m), student_prob, reduction="none", log_target=True
        ).sum(dim=-1)
        batch_loss = student_div + beta * (teacher_div - student_div)
        loss[batch] = batch_loss
    return (loss / student_logits.shape[0]).sum()


# %%
# Benchmark Entry Point Function
# -------------------
def fused_linear_jsd_fwd_tritonbench(
    student_input: torch.Tensor,
    teacher_input: torch.Tensor,
    label: torch.Tensor | None = None,
) -> torch.Tensor:
    assert label is None
    baseline_op = fused_linear_jsd_fwd_tritonbench._self.baseline_op  # pyright: ignore[reportFunctionMemberAccess]
    beta = baseline_op.jsd.beta
    ignore_index = baseline_op.jsd.ignore_index
    temperature = baseline_op.temperature
    student_weight = baseline_op.student_lin.weight
    teacher_weight = baseline_op.teacher_lin.weight
    return fused_linear_jsd_fwd(
        beta,
        ignore_index,
        temperature,
        student_weight,
        teacher_weight,
        student_input,
        teacher_input,
    )


# %%
# Reference Implementation
# --------------------
def fused_linear_jsd_pytorch(
    beta: float,
    ignore_index: int,
    temperature: float,
    student_weight: torch.Tensor,
    teacher_weight: torch.Tensor,
    student_input: torch.Tensor,
    teacher_input: torch.Tensor,
) -> torch.Tensor:
    student_logits = student_input @ student_weight.T
    teacher_logits = teacher_input @ teacher_weight.T
    student_prob = torch.log_softmax(student_logits / temperature, dim=-1)
    teacher_prob = torch.log_softmax(teacher_logits / temperature, dim=-1)
    student_prob = student_prob.to(torch.float).view(-1, student_prob.size(-1))
    teacher_prob = teacher_prob.to(torch.float).view(-1, teacher_prob.size(-1))
    m = torch.exp(student_prob) + beta * (
        torch.exp(teacher_prob) - torch.exp(student_prob)
    )
    teacher_div = torch.nn.functional.kl_div(
        torch.log(m), teacher_prob, reduction="none", log_target=True
    ).sum(dim=-1)
    student_div = torch.nn.functional.kl_div(
        torch.log(m), student_prob, reduction="none", log_target=True
    ).sum(dim=-1)
    loss = student_div + beta * (teacher_div - student_div)
    return (loss / student_logits.shape[0]).sum()


# %%
# Verification Function
# -------------------
def check(m: int, n: int, k: int) -> None:
    student_input = torch.rand([m, n], device="cuda", dtype=torch.float)
    teacher_input = torch.rand([m, n], device="cuda", dtype=torch.float)
    student_weight = torch.rand([k, n], device="cuda", dtype=torch.float)
    teacher_weight = torch.rand([k, n], device="cuda", dtype=torch.float)
    run_example(
        fused_linear_jsd_fwd,
        fused_linear_jsd_pytorch,
        (0.5, 1, 1.0, student_weight, teacher_weight, student_input, teacher_input),
    )


# %%
# Main Function
# -----------
def main() -> None:
    check(1024, 4096, 128256)


if __name__ == "__main__":
    main()
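For reference, the quantity fused_linear_jsd_fwd computes can be read directly off the kernel body above. With P_i the temperature-scaled student softmax for batch row i, Q_i the analogous teacher distribution, and B the batch size:

% Generalized (beta-skewed) Jensen-Shannon divergence, as implemented above.
\[
M_i = (1 - \beta)\, P_i + \beta\, Q_i
\]
\[
\ell_i = (1 - \beta)\,\mathrm{KL}\!\left(P_i \,\|\, M_i\right) + \beta\,\mathrm{KL}\!\left(Q_i \,\|\, M_i\right)
\]
\[
\mathrm{loss} = \frac{1}{B} \sum_{i=1}^{B} \ell_i
\]

With beta = 0.5 this is the standard Jensen-Shannon divergence between student and teacher; ignore_index is threaded through the signature but not used in this forward computation.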

helion/autotuner/base_search.py

Lines changed: 2 additions & 1 deletion

@@ -53,6 +53,7 @@
         "misaligned address",  # CUDA Error
         "PassManager::run failed",  # Triton Error
         "illegal memory access",  # CUDA Error
+        "exceeds triton maximum tensor numel",  # Triton Error
     ],
 )
 )

@@ -149,7 +150,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
             self.log.warning(f"PTXASError compiling config: {config}")
         except Exception as e:
             msg = str(e)
-            if not _expected_errors_regexp.search(msg):
+            if not _expected_errors_regexp.search(msg + str(e.__cause__)):
                 raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
             # Surface Triton IR pass failures more prominently for easier bug reports.
             if "PassManager::run failed" in msg:

test/test_examples.expected

Lines changed: 71 additions & 0 deletions

@@ -841,6 +841,77 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_fp8_gemm, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out

+--- assertExpectedJournal(TestExamples.test_fused_linear_jsd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fused_linear_jsd_fwd(student_logits, teacher_logits, loss, student_input_size_0, student_weight_size_0, loss_stride_0, student_logits_stride_0, student_logits_stride_1, teacher_logits_stride_0, teacher_logits_stride_1, temperature, beta, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < student_input_size_0
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < student_weight_size_0
+    load = tl.load(student_logits + (indices_0[:, None] * student_logits_stride_0 + indices_1[None, :] * student_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = load / temperature
+    _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], v_0, tl.full([], float('-inf'), tl.float32))
+    amax = tl.cast(tl.reshape(tl.max(_mask_to, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_1 = v_0 - amax
+    v_2 = tl_math.exp(v_1)
+    _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], v_2, tl.full([], 0, tl.float32))
+    sum_1 = tl.cast(tl.reshape(tl.sum(_mask_to_1, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_3 = tl_math.log(sum_1)
+    v_4 = v_1 - v_3
+    load_1 = tl.load(teacher_logits + (indices_0[:, None] * teacher_logits_stride_0 + indices_1[None, :] * teacher_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_5 = load_1 / temperature
+    _mask_to_2 = tl.where(mask_0[:, None] & mask_1[None, :], v_5, tl.full([], float('-inf'), tl.float32))
+    amax_1 = tl.cast(tl.reshape(tl.max(_mask_to_2, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_6 = v_5 - amax_1
+    v_7 = tl_math.exp(v_6)
+    _mask_to_3 = tl.where(mask_0[:, None] & mask_1[None, :], v_7, tl.full([], 0, tl.float32))
+    sum_2 = tl.cast(tl.reshape(tl.sum(_mask_to_3, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_8 = tl_math.log(sum_2)
+    v_9 = v_6 - v_8
+    student_prob_1 = tl.reshape(v_4, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
+    teacher_prob_1 = tl.reshape(v_9, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
+    v_10 = tl_math.exp(student_prob_1)
+    v_11 = tl_math.exp(teacher_prob_1)
+    v_12 = tl_math.exp(student_prob_1)
+    v_13 = v_11 - v_12
+    v_14 = v_13 * beta
+    v_15 = v_10 + v_14
+    v_16 = tl_math.log(v_15)
+    v_17 = teacher_prob_1 - v_16
+    v_18 = tl_math.exp(teacher_prob_1)
+    v_19 = v_18 * v_17
+    _mask_to_4 = tl.where(mask_0[:, None] & mask_1[None, :], v_19, tl.full([], 0, tl.float32))
+    teacher_div = tl.cast(tl.sum(_mask_to_4, 1), tl.float32)
+    v_20 = tl_math.log(v_15)
+    v_21 = student_prob_1 - v_20
+    v_22 = tl_math.exp(student_prob_1)
+    v_23 = v_22 * v_21
+    _mask_to_5 = tl.where(mask_0[:, None] & mask_1[None, :], v_23, tl.full([], 0, tl.float32))
+    student_div = tl.cast(tl.sum(_mask_to_5, 1), tl.float32)
+    v_24 = teacher_div - student_div
+    v_25 = v_24 * beta
+    v_26 = student_div + v_25
+    tl.store(loss + indices_0 * loss_stride_0, v_26, mask_0)
+
+def fused_linear_jsd_fwd(beta: float, ignore_index: int, temperature: float, student_weight: torch.Tensor, teacher_weight: torch.Tensor, student_input: torch.Tensor, teacher_input: torch.Tensor, *, _launcher=_default_launcher):
+    student_logits = student_input @ student_weight.T
+    teacher_logits = teacher_input @ teacher_weight.T
+    loss = student_logits.new_empty(student_input.shape[0], dtype=torch.float)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = triton.next_power_of_2(student_weight.size(0))
+    _launcher(_helion_fused_linear_jsd_fwd, (triton.cdiv(student_input.size(0), _BLOCK_SIZE_0),), student_logits, teacher_logits, loss, student_input.size(0), student_weight.size(0), loss.stride(0), student_logits.stride(0), student_logits.stride(1), teacher_logits.stride(0), teacher_logits.stride(1), temperature, beta, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return (loss / student_logits.shape[0]).sum()
+
 --- assertExpectedJournal(TestExamples.test_jagged_dense_add)
 from __future__ import annotations
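As a sanity check on the launch parameters in the generated code above, here is the arithmetic for the shapes the new unit test uses (m, n, k = 64, 128, 256 with block_sizes=[32]); plain Python, where the bit-length trick is just a stand-in for triton.next_power_of_2:

# Launch arithmetic for the expected journal above, using the unit-test shapes.
import math

m, n, k = 64, 128, 256                      # student_input: [m, n], weights: [k, n]
_BLOCK_SIZE_0 = 32                          # block_sizes=[32] passed by the test
_RDIM_SIZE_1 = 1 << (k - 1).bit_length()    # next power of 2 >= k -> 256

grid = math.ceil(m / _BLOCK_SIZE_0)         # triton.cdiv(64, 32) -> 2 programs
print(grid, _RDIM_SIZE_1)                   # 2 256

# Each program handles a 32-row tile and reduces over the padded vocabulary
# dimension, masking out positions >= k (here k is already a power of two,
# so the mask is all-true).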

test/test_examples.py

Lines changed: 35 additions & 0 deletions

@@ -718,6 +718,41 @@ def test_jagged_hstu_attn(self):
             )
         )

+    def test_fused_linear_jsd(self):
+        beta = 0.5
+        ignore_index = 1
+        temperature = 1.0
+        m, n, k = 64, 128, 256
+
+        student_input = torch.randn([m, n], device=DEVICE, dtype=torch.float32)
+        teacher_input = torch.randn([m, n], device=DEVICE, dtype=torch.float32)
+        student_weight = torch.randn([k, n], device=DEVICE, dtype=torch.float32)
+        teacher_weight = torch.randn([k, n], device=DEVICE, dtype=torch.float32)
+
+        args = (
+            beta,
+            ignore_index,
+            temperature,
+            student_weight,
+            teacher_weight,
+            student_input,
+            teacher_input,
+        )
+
+        # Import and use the reference implementation
+        mod = import_path(EXAMPLES_DIR / "fused_linear_jsd.py")
+        expected = mod.fused_linear_jsd_pytorch(*args)
+
+        self.assertExpectedJournal(
+            check_example(
+                "fused_linear_jsd",
+                args,
+                expected,
+                fn_name="fused_linear_jsd_fwd",
+                block_sizes=[32],
+            )
+        )
+

 if __name__ == "__main__":
     unittest.main()
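A rough interactive equivalent of what this test exercises, assuming the repository root is on sys.path and a CUDA device is available; the module path and tolerances are assumptions for illustration, and the first kernel call may trigger Helion autotuning:

# Hypothetical quick comparison mirroring test_fused_linear_jsd.
import torch

from examples.fused_linear_jsd import fused_linear_jsd_fwd, fused_linear_jsd_pytorch

m, n, k = 64, 128, 256
args = (
    0.5,                                 # beta
    1,                                   # ignore_index (unused in the forward pass)
    1.0,                                 # temperature
    torch.randn([k, n], device="cuda"),  # student_weight
    torch.randn([k, n], device="cuda"),  # teacher_weight
    torch.randn([m, n], device="cuda"),  # student_input
    torch.randn([m, n], device="cuda"),  # teacher_input
)
torch.testing.assert_close(
    fused_linear_jsd_fwd(*args), fused_linear_jsd_pytorch(*args), rtol=1e-3, atol=1e-3
)
print("Helion kernel matches the PyTorch reference")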
