
Commit 3d8af25

[Benchmark] jsd kernel and test (#611)
1 parent 6a50593 commit 3d8af25

4 files changed (+640, -106 lines)

benchmarks/run.py

Lines changed: 14 additions & 0 deletions
@@ -72,6 +72,11 @@ class RunResult:
         "examples.swiglu",
         "swiglu_tritonbench",
     ),
+    "jsd": (
+        "tritonbench.operators.jsd.operator",
+        "examples.jsd",
+        "jsd_tritonbench",
+    ),
     "ragged_attention": (
         "tritonbench.operators.ragged_attention.operator",
         "examples.jagged_hstu_attn",
@@ -227,6 +232,14 @@ class RunResult:
         "helion_swiglu_tritonbench-speedup": "helion_speedup",
         "helion_swiglu_tritonbench-accuracy": "helion_accuracy",
     },
+    "jsd": {
+        "liger_jsd-speedup": "triton_speedup",
+        "liger_jsd-accuracy": "triton_accuracy",
+        "torch_compile_jsd-speedup": "torch_compile_speedup",
+        "torch_compile_jsd-accuracy": "torch_compile_accuracy",
+        "helion_jsd_tritonbench-speedup": "helion_speedup",
+        "helion_jsd_tritonbench-accuracy": "helion_accuracy",
+    },
 }


@@ -273,6 +286,7 @@ def check_and_setup_tritonbench() -> None:
     # Clone to benchmarks/tritonbench
     benchmarks_dir = Path(__file__).parent
     tritonbench_path = benchmarks_dir / "tritonbench"
+    print(f"Using tritonbench path: {tritonbench_path}")

     try:
         # Clone the repository if it doesn't exist
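For orientation, the sketch below shows how a registration tuple like the new "jsd" entry could be resolved at run time: the first element names the tritonbench operator module, the second the Helion example module, and the third the wrapper function inside it. The resolver is a hypothetical helper for illustration, not the actual logic in benchmarks/run.py, and it assumes the repository root is on the Python path so that examples.jsd is importable.

import importlib


def resolve_kernel(entry: tuple[str, str, str]) -> tuple[str, object]:
    # Hypothetical helper: split a (tritonbench operator, example module, function)
    # tuple and import the Helion wrapper function by name.
    operator_module, example_module, func_name = entry
    example = importlib.import_module(example_module)  # e.g. examples.jsd
    return operator_module, getattr(example, func_name)  # e.g. jsd_tritonbench


op_path, kernel_fn = resolve_kernel(
    ("tritonbench.operators.jsd.operator", "examples.jsd", "jsd_tritonbench")
)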

examples/jsd.py

Lines changed: 342 additions & 0 deletions
@@ -0,0 +1,342 @@
"""
Helion JSD (Jensen-Shannon Divergence) Example
==============================================
This example demonstrates a Helion kernel implementation of Jensen-Shannon Divergence.
JSD is commonly used in knowledge distillation for language models, where:

    JSD(beta)(P || Q) = beta * KL(P || M) + (1-beta) * KL(Q || M)
    where M = beta * P + (1-beta) * Q is the mixture distribution

The generalized JSD reduces to:
- Forward KL when beta = 0: KL(P || Q)
- Reverse KL when beta = 1: KL(Q || P)
- Symmetric JSD when beta = 0.5

Based on liger_kernel's JSD implementation used for knowledge distillation in language models.
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch import Tensor
import torch.nn as nn

import helion
from helion._testing import run_example
import helion.language as hl

if TYPE_CHECKING:
    from collections.abc import Callable


# %%
# JSD Kernel
# ----------
@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def jsd_forward(
    _input: Tensor,  # student predictions (input) in log-space
    target: Tensor,  # teacher targets in log-space
    shift_labels: Tensor | None = None,
    beta: float = 0.5,
    ignore_index: int = -100,
) -> tuple[Tensor, Tensor]:
    """
    Compute Jensen-Shannon Divergence loss.

    Args:
        _input: Student predictions in log-space, shape (BT, V)
        target: Teacher targets in log-space, shape (BT, V)
        shift_labels: Optional labels for masking, shape (BT,)
        beta: Coefficient for generalized JSD in [0, 1]
        ignore_index: Index to ignore in labels

    Returns:
        loss: Scalar JSD loss
        dX: Gradient of loss wrt input
    """
    BT, V = _input.shape
    assert target.shape == _input.shape, (
        f"Shape mismatch: {target.shape} != {_input.shape}"
    )
    n_rows = BT

    # Create output tensor for accumulating loss
    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
    dX = torch.empty_like(_input)

    # Count non-ignored elements
    n_non_ignore = float(BT)
    if shift_labels is not None:
        n_non_ignore = float((shift_labels != ignore_index).sum().item())
        if n_non_ignore == 0:
            return torch.zeros(
                [], dtype=_input.dtype, device=_input.device
            ), torch.zeros_like(_input)

    # Process each sequence position
    BT_SIZE = helion.cdiv(BT, n_rows)  # The liger kernel uses 1
    for tile_bt in hl.tile(BT, block_size=BT_SIZE):
        # Check for label masking
        if shift_labels is not None:
            if shift_labels[tile_bt] == ignore_index:
                for tile_X in hl.tile(V):
                    dX[tile_bt, tile_X] = 0.0
                continue

        for tile_v in hl.tile(V):
            # Load log probabilities and convert to float32
            X = _input[tile_bt, tile_v]
            Y = target[tile_bt, tile_v]
            X_max = torch.amax(X, dim=0)
            Y_max = torch.amax(Y, dim=0)

            if beta == 0.0:  # Forward KL: KL(P || Q)
                Y_shift = Y - Y_max
                Y_prob = torch.exp(Y_shift) * torch.exp(
                    Y_max
                )  # Compensate for the shift
                loss[tile_bt, tile_v] = Y_prob * (Y - X)
                dX[tile_bt, tile_v] = -Y_prob
            elif beta == 1.0:  # Reverse KL: KL(Q || P)
                X_shift = X - X_max
                X_prob = torch.exp(X_shift) * torch.exp(
                    X_max
                )  # Compensate for the shift
                loss[tile_bt, tile_v] = X_prob * (X - Y)
                dX[tile_bt, tile_v] = loss[tile_bt, tile_v] + X_prob
            else:  # General JSD: beta*KL(P||M) + (1-beta)*KL(Q||M)
                max_val = torch.maximum(X_max, Y_max)
                X_shifted = X - max_val
                Y_shifted = Y - max_val

                exp_max = torch.exp(max_val)

                Q = torch.exp(X_shifted) * exp_max  # = exp(X)
                P = torch.exp(Y_shifted) * exp_max  # = exp(Y)

                beta_P = beta * P
                one_minus_beta_Q = (1 - beta) * Q
                M = beta_P + one_minus_beta_Q
                log_M = torch.log(
                    M
                )  # No need to compensate as M is already in original scale

                loss[tile_bt, tile_v] = beta_P * Y + one_minus_beta_Q * X - M * log_M
                dX[tile_bt, tile_v] = one_minus_beta_Q * (X - log_M)

            # Accumulate over vocabulary dimension
            scale = 1.0 / n_non_ignore
            loss[tile_bt, tile_v] = loss[tile_bt, tile_v] * scale
            dX[tile_bt, tile_v] = dX[tile_bt, tile_v] * scale

    # Normalize by number of non-ignored elements, run it on host to match liger_kernel
    final_loss = torch.sum(
        loss
    )  # This line raises a warning: helion.exc.TensorOperationInWrapper

    return final_loss, dX


# %%
# JSD Loss Module (matches liger_kernel structure)
# ------------------------------------------------
class HelionJSD(nn.Module):
    """
    Helion implementation of Jensen-Shannon Divergence matching liger_kernel.LigerJSD structure.

    JSD(beta)(P || Q) = beta * KL(P || M) + (1-beta) * KL(Q || M)
    where M = beta * P + (1-beta) * Q

    Args:
        beta: Coefficient beta ∈ [0,1]. When beta=0: forward KL, beta=1: reverse KL, beta=0.5: symmetric JSD
        ignore_index: Index to ignore in labels for masking
        dtype: Data type for loss computation
    """

    def __init__(
        self,
        beta: float = 0.5,
        ignore_index: int = -100,
        dtype: torch.dtype = torch.float,
    ) -> None:
        super().__init__()
        self.beta = beta
        self.ignore_index = ignore_index
        self.dtype = dtype

    def forward(
        self,
        _input: Tensor,  # student predictions in log-space
        target: Tensor,  # teacher targets in log-space
        shift_labels: Tensor | None = None,
    ) -> Tensor:
        """
        Forward pass computing JSD loss.

        Args:
            _input: Student predictions in log-space, shape (BT, V)
            target: Teacher targets in log-space, shape (BT, V)
            shift_labels: Optional labels for masking, shape (BT,)
        Returns:
            Scalar JSD loss
        """
        if shift_labels is not None:
            assert shift_labels.shape == (_input.shape[0],), (
                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            )
            shift_labels = shift_labels.contiguous()
        loss, dX = jsd_forward(
            _input, target, shift_labels, self.beta, self.ignore_index
        )
        return loss.to(self.dtype)


class TorchJSDBaseline(nn.Module):
    """PyTorch baseline JSD implementation matching tritonbench."""

    def __init__(
        self,
        beta: float = 0.5,
        ignore_index: int = -100,
        dtype: torch.dtype = torch.float,
    ) -> None:
        super().__init__()
        self.kl = nn.KLDivLoss(reduction="none", log_target=True)
        self.beta = beta
        self.ignore_index = ignore_index
        self.dtype = dtype

    def forward(
        self, log_q: Tensor, log_p: Tensor, label: Tensor | None = None
    ) -> Tensor:
        # Convert to float for computation
        log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
        log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))

        # Mixture distribution
        m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)

        # JSD loss
        loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (
            1 - self.beta
        ) * self.kl(torch.log(m), log_q).sum(dim=-1)

        if label is not None:
            loss = torch.where(label != self.ignore_index, loss, 0.0)
            n_non_ignore = (label != self.ignore_index).sum().item()
            if n_non_ignore == 0:
                loss = torch.tensor(0.0, device=log_q.device, dtype=torch.float)
            else:
                loss = (loss / n_non_ignore).sum()
        else:
            loss = (loss / log_q.shape[0]).sum()

        return loss.to(self.dtype)


# %%
# Verification Function
# ---------------------
def check_jsd_kernel(
    B: int,
    T: int,
    V: int,
    beta: float = 0.5,
    ignore_index: int = -100,
    use_labels: bool = False,
) -> None:
    """
    Verify the JSD kernel implementation against PyTorch's baseline.

    Args:
        B: Batch size (B)
        T: Sequence length (T)
        V: Vocabulary size (V)
        beta: JSD coefficient
        ignore_index: Index to ignore in labels
        use_labels: Whether to test with label masking
    """
    # Create test tensors
    log_q = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(dim=-1)
    log_p = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)

    shift_labels = None
    if use_labels:
        shift_labels = torch.randint(0, V, (B,), device="cuda")
        # Randomly set some to ignore_index
        shift_labels[torch.rand(B, device="cuda") < 0.1] = -100

    # Test forward pass only (no gradients for now)
    helion_jsd = HelionJSD(beta=beta, ignore_index=ignore_index)
    torch_jsd = TorchJSDBaseline(beta=beta, ignore_index=ignore_index)

    def helion_wrapper(
        log_q: Tensor, log_p: Tensor, shift_labels: Tensor | None = None
    ) -> Tensor:
        return helion_jsd(log_q, log_p, shift_labels)

    def baseline_wrapper(
        log_q: Tensor, log_p: Tensor, shift_labels: Tensor | None = None
    ) -> Tensor:
        return torch_jsd(log_q, log_p, shift_labels)

    run_example(helion_wrapper, baseline_wrapper, (log_q, log_p, shift_labels))


# %%
# Tritonbench Integration
# -----------------------
def jsd_tritonbench(tb_op: object, log_q: Tensor, log_p: Tensor) -> Callable:
    """
    Wrapper for tritonbench that matches its interface.

    Args:
        log_q: Student predictions in log-space
        log_p: Teacher targets in log-space

    Returns:
        Callable: A callable that runs the JSD kernel
    """

    baseline_model = tb_op.baseline_op  # pyright: ignore[reportAttributeAccessIssue]

    helion_jsd = HelionJSD(
        beta=baseline_model.beta,
        ignore_index=baseline_model.ignore_index,
        dtype=baseline_model.dtype,
    )

    return lambda: helion_jsd(log_q, log_p)


# %%
# Main Function
# -------------
def main() -> None:
    """
    Main entry point that runs JSD kernel verification.
    Tests various configurations including different beta values and label masking.
    """
    print("Testing JSD kernel...")
    B = 4
    T = 2048
    beta = 0.5
    ignore_index = -100
    use_labels = False

    for V in [2**i for i in range(12, 18)]:
        print(
            f"Testing JSD: B={B}, T={T}, V={V}, beta={beta}, ignore_index={ignore_index}, labels={use_labels}"
        )
        check_jsd_kernel(B, T, V, beta, ignore_index, use_labels)
        print("✓ JSD passed")


# %%
if __name__ == "__main__":
    main()
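As a standalone illustration (plain PyTorch, independent of the Helion kernel and not part of the example file above), the general branch of jsd_forward relies on the rewrite beta*KL(P||M) + (1-beta)*KL(Q||M) = sum(beta*P*log P + (1-beta)*Q*log Q - M*log M) with M = beta*P + (1-beta)*Q. A quick numerical check of that identity on an illustrative vocabulary size and beta:

import torch

torch.manual_seed(0)
beta, V = 0.3, 16
log_q = torch.randn(V).log_softmax(-1)  # X: student log-probabilities
log_p = torch.randn(V).log_softmax(-1)  # Y: teacher log-probabilities
P, Q = log_p.exp(), log_q.exp()
M = beta * P + (1 - beta) * Q

# Definition: beta * KL(P || M) + (1 - beta) * KL(Q || M)
definition = (
    beta * (P * (log_p - M.log())).sum()
    + (1 - beta) * (Q * (log_q - M.log())).sum()
)
# Pointwise rewrite summed over V, as in the kernel's general branch
rewrite = (beta * P * log_p + (1 - beta) * Q * log_q - M * M.log()).sum()
assert torch.allclose(definition, rewrite, atol=1e-6)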
