Commit a70562a

[Benchmark] kl_div kernel and test (#615)
1 parent dc1f48e commit a70562a

4 files changed: +381 -0 lines changed

benchmarks/run.py

Lines changed: 13 additions & 0 deletions
@@ -77,6 +77,11 @@ class RunResult:
        "examples.jsd",
        "jsd_tritonbench",
    ),
    "kl_div": (
        "tritonbench.operators.kl_div.operator",
        "examples.kl_div",
        "kl_div_tritonbench",
    ),
    "ragged_attention": (
        "tritonbench.operators.ragged_attention.operator",
        "examples.jagged_hstu_attn",
@@ -253,6 +258,14 @@ class RunResult:
        "helion_welford-speedup": "helion_speedup",
        "helion_welford-accuracy": "helion_accuracy",
    },
    "kl_div": {
        "liger_kl_div-speedup": "triton_speedup",
        "liger_kl_div-accuracy": "triton_accuracy",
        "torch_compile_kl_div-speedup": "torch_compile_speedup",
        "torch_compile_kl_div-accuracy": "torch_compile_accuracy",
        "helion_kl_div_tritonbench-speedup": "helion_speedup",
        "helion_kl_div_tritonbench-accuracy": "helion_accuracy",
    },
}
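For context, the new entries follow run.py's existing pattern: the first table maps a kernel name to a (tritonbench operator module, Helion example module, callable name) tuple, and the second renames tritonbench metric columns to the normalized speedup/accuracy names. The sketch below only illustrates how such a tuple can be resolved; the helper resolve_kl_div_example is hypothetical and not part of run.py.

# Illustrative sketch only -- resolve_kl_div_example is a hypothetical helper,
# not something defined in benchmarks/run.py.
import importlib
from typing import Callable

KL_DIV_MAPPING = (
    "tritonbench.operators.kl_div.operator",  # tritonbench operator module
    "examples.kl_div",                        # Helion example module
    "kl_div_tritonbench",                     # callable exported by the example
)

def resolve_kl_div_example() -> Callable[..., object]:
    _operator_path, example_path, func_name = KL_DIV_MAPPING
    example_module = importlib.import_module(example_path)  # imports examples/kl_div.py
    return getattr(example_module, func_name)               # -> kl_div_tritonbench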

examples/kl_div.py

Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
"""
Helion KL Divergence Example
============================
This example demonstrates a Helion kernel implementation of Kullback-Leibler Divergence.
KL divergence is commonly used in deep learning for comparing probability distributions:

KL(P || Q) = sum_i P(i) * log(P(i) / Q(i))

When the input is in log-space (as common with log-softmax outputs):
KL(P || Q) = sum_i P(i) * (log(P(i)) - log(Q(i)))

The loss supports different reduction modes:
- 'none': No reduction, returns per-example losses
- 'sum': Sum all losses
- 'mean': Average over all elements
- 'batchmean': Average over batch dimension

Based on liger_kernel's KL divergence implementation used in language models.
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch import Tensor
import torch.nn as nn

import helion
from helion._testing import run_example
import helion.language as hl

if TYPE_CHECKING:
    from collections.abc import Callable


# %%
# KL Divergence Kernel
# --------------------
@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def kl_div_forward(
    y_pred: Tensor,  # input predictions in log-space, shape (BT, V)
    y_true: Tensor,  # target values, shape (BT, V)
    log_target: bool = False,
    reduction: str = "batchmean",
    eps: float = 1e-10,
) -> Tensor:
    """
    Compute KL Divergence loss.

    Args:
        y_pred: Input predictions in log-space, shape (BT, V)
        y_true: Target values (probabilities or log-probabilities), shape (BT, V)
        log_target: If True, y_true is in log-space; if False, y_true is probabilities
        reduction: Reduction mode ('none', 'sum', 'mean', 'batchmean')
        eps: Small value to avoid numerical issues

    Returns:
        loss: KL divergence loss
    """
    BT, V = y_pred.shape
    assert y_true.shape == y_pred.shape, (
        f"Shape mismatch: {y_true.shape} != {y_pred.shape}"
    )

    # Initialize loss accumulator
    if reduction == "none":
        loss = torch.zeros_like(y_pred)
    else:
        loss = torch.zeros((BT,), dtype=torch.float32, device=y_pred.device)

    kl_loss = torch.zeros_like(y_pred)

    # Call register_block_size to know block_size_n outside of the reduction loop.
    block_size_n = hl.register_block_size(V)

    BT_SIZE = helion.cdiv(BT, BT)  # Process all at once for simplicity
    for tile_bt in hl.tile(BT, block_size=BT_SIZE):
        loss_sum = hl.zeros([tile_bt, block_size_n], dtype=torch.float32)

        for tile_v in hl.tile(V, block_size=block_size_n):
            y_pred_val = y_pred[tile_bt, tile_v]
            y_true_val = y_true[tile_bt, tile_v]

            if log_target:
                # KL(P || Q) = exp(y_true) * (y_true - y_pred) when both in log-space
                prob_true = torch.exp(y_true_val)
                kl_loss[tile_bt, tile_v] = prob_true * (y_true_val - y_pred_val)

            else:
                # KL(P || Q) = y_true * (log(y_true) - y_pred) when y_pred in log-space
                log_true = torch.log(torch.clamp(y_true_val, min=eps))
                kl_loss[tile_bt, tile_v] = y_true_val * (log_true - y_pred_val)

            if reduction == "none":
                loss[tile_bt, tile_v] = kl_loss[tile_bt, tile_v]
            else:
                # Sum over vocabulary dimension
                loss_sum += kl_loss[tile_bt, tile_v]

        if reduction != "none":
            loss[tile_bt] = loss_sum.sum(dim=-1)

    # Apply final reduction
    if reduction == "batchmean":
        final_loss = torch.sum(loss) / BT
    elif reduction == "sum":
        final_loss = torch.sum(loss, dim=0)
    elif reduction == "mean":
        final_loss = torch.sum(loss) / (BT * V)
    else:  # reduction == "none"
        final_loss = loss

    return final_loss


# %%
# KL Divergence Loss Module
# -------------------------
class HelionKLDivLoss(nn.Module):
    """
    Helion implementation of KL Divergence Loss matching PyTorch's KLDivLoss.

    KL(P || Q) computes the divergence between target distribution P and input Q.

    Args:
        reduction: Reduction mode ('none', 'sum', 'mean', 'batchmean')
        log_target: If True, target is in log-space; if False, target is probabilities
        eps: Small value for numerical stability
    """

    def __init__(
        self,
        reduction: str = "batchmean",
        log_target: bool = False,
        eps: float = 1e-10,
    ) -> None:
        super().__init__()
        self.reduction = reduction
        self.log_target = log_target
        self.eps = eps

    def forward(self, input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
        """
        Forward pass computing KL divergence loss.

        Args:
            input_tensor: Input predictions in log-space, shape (BT, V)
            target_tensor: Target values (probabilities or log-probabilities), shape (BT, V)

        Returns:
            KL divergence loss
        """
        return kl_div_forward(
            input_tensor, target_tensor, self.log_target, self.reduction, self.eps
        )


# %%
# Verification Function
# ---------------------
def check_kl_div_kernel(
    B: int,
    T: int,
    V: int,
    reduction: str = "batchmean",
    log_target: bool = False,
    eps: float = 1e-10,
) -> None:
    """
    Verify the KL divergence kernel implementation against PyTorch's baseline.

    Args:
        B: Batch size
        T: Sequence length
        V: Vocabulary size
        reduction: Reduction mode
        log_target: Whether target is in log-space
        eps: Small value for numerical stability
    """
    # Create test tensors following tritonbench pattern
    input_tensor = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
        dim=-1
    )

    target_tensor = torch.randn(B * T, V, device="cuda").softmax(dim=-1)

    # Test forward pass
    helion_kl = HelionKLDivLoss(reduction=reduction, log_target=log_target, eps=eps)
    torch_kl_div = torch.nn.KLDivLoss(reduction="batchmean", log_target=log_target).to(
        "cuda"
    )

    def helion_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
        return helion_kl(input_tensor, target_tensor)

    def baseline_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
        return torch_kl_div(input_tensor, target_tensor)

    run_example(helion_wrapper, baseline_wrapper, (input_tensor, target_tensor))


# %%
# Tritonbench Integration
# -----------------------
def kl_div_tritonbench(
    tb_op: object, input_tensor: Tensor, target_tensor: Tensor
) -> Callable:
    """
    Wrapper for tritonbench that matches its interface.

    Args:
        tb_op: Tritonbench operator object
        input_tensor: Input predictions in log-space
        target_tensor: Target values

    Returns:
        Callable: A callable that runs the KL divergence kernel
    """
    helion_kl = HelionKLDivLoss(
        reduction="batchmean",
        log_target=False,  # tritonbench uses probabilities, not log-probabilities
        eps=1e-10,
    )

    return lambda: helion_kl(input_tensor, target_tensor)


# %%
# Main Function
# -------------
def main() -> None:
    """
    Main entry point that runs KL divergence kernel verification.
    Tests various configurations matching tritonbench settings.
    """
    print("Testing KL divergence kernel...")
    B = 8
    T = 512
    reduction = "batchmean"
    log_target = False
    eps = 1e-10

    # Test with vocabulary sizes from tritonbench (2^12 to 2^17)
    for V in [2**i for i in range(12, 18)]:
        print(
            f"Testing KL Div: B={B}, T={T}, V={V}, reduction={reduction}, log_target={log_target}"
        )
        check_kl_div_kernel(B, T, V, reduction, log_target, eps)
        print("✓ KL Div passed")


# %%
if __name__ == "__main__":
    main()
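
A minimal usage sketch for the new example, not part of the commit: it assumes a CUDA device and that examples.kl_div is importable from the repository root, and the tolerances are illustrative rather than taken from the test suite.

# Usage sketch (assumptions noted above): compare HelionKLDivLoss against
# torch.nn.KLDivLoss on one tritonbench-style shape.
import torch
from examples.kl_div import HelionKLDivLoss

B, T, V = 8, 512, 4096  # one of the shapes exercised by main()
y_pred = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)  # log-probabilities
y_true = torch.randn(B * T, V, device="cuda").softmax(dim=-1)      # probabilities

helion_kl = HelionKLDivLoss(reduction="batchmean", log_target=False)
torch_kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=False)

torch.testing.assert_close(
    helion_kl(y_pred, y_true),
    torch_kl(y_pred, y_true),
    rtol=1e-4,
    atol=1e-4,  # illustrative tolerances
)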

test/test_examples.expected

Lines changed: 86 additions & 0 deletions
@@ -1473,6 +1473,92 @@ def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None
    final_loss = torch.sum(loss)
    return (final_loss, dX)

--- assertExpectedJournal(TestExamples.test_kl_div)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kl_div_forward(y_pred, y_true, kl_loss, loss, kl_loss_stride_0, kl_loss_stride_1, loss_stride_0, y_pred_stride_0, y_pred_stride_1, y_true_stride_0, y_true_stride_1, BT, V, log_target, eps, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_1 = pid_0 * _BLOCK_SIZE_1
    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    mask_1 = indices_1 < BT
    loss_sum = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
    for offset_0 in tl.range(0, V.to(tl.int32), _BLOCK_SIZE_0):
        indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
        mask_0 = indices_0 < V
        loss_sum_copy = loss_sum
        loss_sum_copy_0 = loss_sum_copy
        y_pred_val = tl.load(y_pred + (indices_1[:, None] * y_pred_stride_0 + indices_0[None, :] * y_pred_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
        y_true_val = tl.load(y_true + (indices_1[:, None] * y_true_stride_0 + indices_0[None, :] * y_true_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
        if log_target:
            y_true_val_copy = y_true_val
            y_pred_val_copy = y_pred_val
            y_true_val_copy_0 = y_true_val_copy
            y_pred_val_copy_0 = y_pred_val_copy
            v_0 = libdevice.exp(y_true_val_copy_0)
            v_1 = y_true_val_copy_0 - y_pred_val_copy_0
            v_2 = v_0 * v_1
            tl.store(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), v_2, mask_1[:, None] & mask_0[None, :])
        _not = not log_target
        if _not:
            y_true_val_copy_1 = y_true_val
            y_pred_val_copy_1 = y_pred_val
            y_true_val_copy_1_0 = y_true_val_copy_1
            y_pred_val_copy_1_0 = y_pred_val_copy_1
            v_3 = triton_helpers.maximum(y_true_val_copy_1_0, eps)
            v_4 = tl_math.log(v_3)
            v_5 = v_4 - y_pred_val_copy_1_0
            v_6 = y_true_val_copy_1_0 * v_5
            tl.store(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), v_6, mask_1[:, None] & mask_0[None, :])
        load_2 = tl.load(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
        loss_sum = loss_sum_copy_0 + load_2
    sum_1 = tl.cast(tl.sum(loss_sum, 1), tl.float32)
    tl.store(loss + indices_1 * loss_stride_0, sum_1, mask_1)

def kl_div_forward(y_pred: Tensor, y_true: Tensor, log_target: bool=False, reduction: str='batchmean', eps: float=1e-10, *, _launcher=_default_launcher):
    """
    Compute KL Divergence loss.

    Args:
        y_pred: Input predictions in log-space, shape (BT, V)
        y_true: Target values (probabilities or log-probabilities), shape (BT, V)
        log_target: If True, y_true is in log-space; if False, y_true is probabilities
        reduction: Reduction mode ('none', 'sum', 'mean', 'batchmean')
        eps: Small value to avoid numerical issues

    Returns:
        loss: KL divergence loss
    """
    BT, V = y_pred.shape
    assert y_true.shape == y_pred.shape, f'Shape mismatch: {y_true.shape} != {y_pred.shape}'
    if reduction == 'none':
        loss = torch.zeros_like(y_pred)
    else:
        loss = torch.zeros((BT,), dtype=torch.float32, device=y_pred.device)
    kl_loss = torch.zeros_like(y_pred)
    BT_SIZE = helion.cdiv(BT, BT)
    _BLOCK_SIZE_1 = BT_SIZE
    _BLOCK_SIZE_0 = 4096
    _launcher(_helion_kl_div_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), y_pred, y_true, kl_loss, loss, kl_loss.stride(0), kl_loss.stride(1), loss.stride(0), y_pred.stride(0), y_pred.stride(1), y_true.stride(0), y_true.stride(1), BT, V, log_target, eps, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
    if reduction == 'batchmean':
        final_loss = torch.sum(loss) / BT
    elif reduction == 'sum':
        final_loss = torch.sum(loss, dim=0)
    elif reduction == 'mean':
        final_loss = torch.sum(loss) / (BT * V)
    else:
        final_loss = loss
    return final_loss

--- assertExpectedJournal(TestExamples.test_layernorm_bwd_dwdb)
from __future__ import annotations
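
For reference, the host-side reductions that appear both in examples/kl_div.py and in the generated wrapper above can be summarized as follows, writing $\ell_i = \sum_{j=1}^{V} \mathrm{kl\_loss}_{ij}$ for the per-row sum over the vocabulary dimension:

$$\mathcal{L}_{\text{batchmean}} = \frac{1}{BT}\sum_{i=1}^{BT} \ell_i, \qquad \mathcal{L}_{\text{sum}} = \sum_{i=1}^{BT} \ell_i, \qquad \mathcal{L}_{\text{mean}} = \frac{1}{BT \cdot V}\sum_{i=1}^{BT} \ell_i,$$

while reduction='none' returns the per-element kl_loss matrix unreduced.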
