13 changes: 13 additions & 0 deletions benchmarks/run.py
@@ -77,6 +77,11 @@ class RunResult:
"examples.jsd",
"jsd_tritonbench",
),
"kl_div": (
"tritonbench.operators.kl_div.operator",
"examples.kl_div",
"kl_div_tritonbench",
),
"ragged_attention": (
"tritonbench.operators.ragged_attention.operator",
"examples.jagged_hstu_attn",
@@ -253,6 +258,14 @@ class RunResult:
"helion_welford-speedup": "helion_speedup",
"helion_welford-accuracy": "helion_accuracy",
},
"kl_div": {
"liger_kl_div-speedup": "triton_speedup",
"liger_kl_div-accuracy": "triton_accuracy",
"torch_compile_kl_div-speedup": "torch_compile_speedup",
"torch_compile_kl_div-accuracy": "torch_compile_accuracy",
"helion_kl_div_tritonbench-speedup": "helion_speedup",
"helion_kl_div_tritonbench-accuracy": "helion_accuracy",
},
}


258 changes: 258 additions & 0 deletions examples/kl_div.py
@@ -0,0 +1,258 @@
"""
Helion KL Divergence Example
============================
This example demonstrates a Helion kernel implementation of Kullback-Leibler Divergence.
KL divergence is commonly used in deep learning for comparing probability distributions:

KL(P || Q) = sum_i P(i) * log(P(i) / Q(i))

When the input is in log-space (as is common with log-softmax outputs):
KL(P || Q) = sum_i P(i) * (log(P(i)) - log(Q(i)))

The loss supports different reduction modes:
- 'none': No reduction, returns per-example losses
- 'sum': Sum all losses
- 'mean': Average over all elements
- 'batchmean': Average over batch dimension

Based on liger_kernel's KL divergence implementation used in language models.
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch import Tensor
import torch.nn as nn

import helion
from helion._testing import run_example
import helion.language as hl

if TYPE_CHECKING:
from collections.abc import Callable


# %%
# KL Divergence Kernel
# --------------------
@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def kl_div_forward(
y_pred: Tensor, # input predictions in log-space, shape (BT, V)
y_true: Tensor, # target values, shape (BT, V)
log_target: bool = False,
reduction: str = "batchmean",
eps: float = 1e-10,
) -> Tensor:
"""
Compute KL Divergence loss.

Args:
y_pred: Input predictions in log-space, shape (BT, V)
y_true: Target values (probabilities or log-probabilities), shape (BT, V)
log_target: If True, y_true is in log-space; if False, y_true is probabilities
reduction: Reduction mode ('none', 'sum', 'mean', 'batchmean')
eps: Small value to avoid numerical issues

Returns:
loss: KL divergence loss
"""
BT, V = y_pred.shape
assert y_true.shape == y_pred.shape, (
f"Shape mismatch: {y_true.shape} != {y_pred.shape}"
)

# Initialize loss accumulator
if reduction == "none":
loss = torch.zeros_like(y_pred)
else:
loss = torch.zeros((BT,), dtype=torch.float32, device=y_pred.device)

kl_loss = torch.zeros_like(y_pred)

# Call register_block_size to know block_size_n outside of the reduction loop.
block_size_n = hl.register_block_size(V)

BT_SIZE = helion.cdiv(BT, BT) # Process all at once for simplicity
for tile_bt in hl.tile(BT, block_size=BT_SIZE):
loss_sum = hl.zeros([tile_bt, block_size_n], dtype=torch.float32)

for tile_v in hl.tile(V, block_size=block_size_n):
y_pred_val = y_pred[tile_bt, tile_v]
y_true_val = y_true[tile_bt, tile_v]

if log_target:
# KL(P || Q) = exp(y_true) * (y_true - y_pred) when both in log-space
prob_true = torch.exp(y_true_val)
kl_loss[tile_bt, tile_v] = prob_true * (y_true_val - y_pred_val)

else:
# KL(P || Q) = y_true * (log(y_true) - y_pred) when y_pred in log-space
log_true = torch.log(torch.clamp(y_true_val, min=eps))
kl_loss[tile_bt, tile_v] = y_true_val * (log_true - y_pred_val)

if reduction == "none":
loss[tile_bt, tile_v] = kl_loss[tile_bt, tile_v]
else:
# Sum over vocabulary dimension
loss_sum += kl_loss[tile_bt, tile_v]

if reduction != "none":
loss[tile_bt] = loss_sum.sum(dim=-1)

# Apply final reduction
if reduction == "batchmean":
final_loss = torch.sum(loss) / BT
elif reduction == "sum":
final_loss = torch.sum(loss, dim=0)
elif reduction == "mean":
final_loss = torch.sum(loss) / (BT * V)
else: # reduction == "none"
final_loss = loss
Contributor: Should we test all reductions in the unit tests / main function? Or maybe simplify it to just the one case that's being used by tritonbench?

Contributor (Author): tritonbench only tests the reduction = "batchmean" case: https://github.com/meta-pytorch/tritonbench/blob/main/tritonbench/operators/kl_div/operator.py#L28
while liger_kernel implements all cases. I mirrored the liger_kernel implementation: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/kl_div.py#L150

Contributor (Author): Would you suggest I add unit tests for all reductions, or only implement reduction = "batchmean"?

Contributor: I see, let's just keep it as-is for now.


return final_loss


# %%
# KL Divergence Loss Module
# -------------------------
class HelionKLDivLoss(nn.Module):
"""
Helion implementation of KL Divergence Loss matching PyTorch's KLDivLoss.

KL(P || Q) computes the divergence between target distribution P and input Q.

Args:
reduction: Reduction mode ('none', 'sum', 'mean', 'batchmean')
log_target: If True, target is in log-space; if False, target is probabilities
eps: Small value for numerical stability
"""

def __init__(
self,
reduction: str = "batchmean",
log_target: bool = False,
eps: float = 1e-10,
) -> None:
super().__init__()
self.reduction = reduction
self.log_target = log_target
self.eps = eps

def forward(self, input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
"""
Forward pass computing KL divergence loss.

Args:
input_tensor: Input predictions in log-space, shape (BT, V)
target_tensor: Target values (probabilities or log-probabilities), shape (BT, V)

Returns:
KL divergence loss
"""
return kl_div_forward(
input_tensor, target_tensor, self.log_target, self.reduction, self.eps
)


# %%
# Verification Function
# ---------------------
def check_kl_div_kernel(
B: int,
T: int,
V: int,
reduction: str = "batchmean",
log_target: bool = False,
eps: float = 1e-10,
) -> None:
"""
Verify the KL divergence kernel implementation against PyTorch's baseline.

Args:
B: Batch size
T: Sequence length
V: Vocabulary size
reduction: Reduction mode
log_target: Whether target is in log-space
eps: Small value for numerical stability
"""
# Create test tensors following tritonbench pattern
input_tensor = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
dim=-1
)

target_tensor = torch.randn(B * T, V, device="cuda").softmax(dim=-1)

# Test forward pass
helion_kl = HelionKLDivLoss(reduction=reduction, log_target=log_target, eps=eps)
torch_kl_div = torch.nn.KLDivLoss(reduction=reduction, log_target=log_target).to(
"cuda"
)

def helion_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
return helion_kl(input_tensor, target_tensor)

def baseline_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
return torch_kl_div(input_tensor, target_tensor)

run_example(helion_wrapper, baseline_wrapper, (input_tensor, target_tensor))


# %%
# Tritonbench Integration
# -----------------------
def kl_div_tritonbench(
tb_op: object, input_tensor: Tensor, target_tensor: Tensor
) -> Callable:
"""
Wrapper for tritonbench that matches its interface.

Args:
tb_op: Tritonbench operator object
input_tensor: Input predictions in log-space
target_tensor: Target values

Returns:
Callable: A callable that runs the KL divergence kernel
"""
helion_kl = HelionKLDivLoss(
reduction="batchmean",
log_target=False, # tritonbench uses probabilities, not log-probabilities
eps=1e-10,
)

return lambda: helion_kl(input_tensor, target_tensor)
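

# %%
# Timing sketch
# -------------
# A rough sketch of how the returned callable could be timed on its own, using
# triton.testing.do_bench (the tritonbench harness handles timing itself when
# this is run through benchmarks/run.py; shapes below are arbitrary, and tb_op
# is unused by the wrapper, so None suffices here):
#
#     from triton.testing import do_bench
#
#     inp = torch.randn(4096, 4096, device="cuda").log_softmax(dim=-1)
#     tgt = torch.randn(4096, 4096, device="cuda").softmax(dim=-1)
#     fn = kl_div_tritonbench(None, inp, tgt)
#     ms = do_bench(fn)  # runtime in milliseconds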


# %%
# Main Function
# -------------
def main() -> None:
"""
Main entry point that runs KL divergence kernel verification.
Tests various configurations matching tritonbench settings.
"""
print("Testing KL divergence kernel...")
B = 8
T = 512
reduction = "batchmean"
log_target = False
eps = 1e-10

# Test with vocabulary sizes from tritonbench (2^12 to 2^17)
for V in [2**i for i in range(12, 18)]:
print(
f"Testing KL Div: B={B}, T={T}, V={V}, reduction={reduction}, log_target={log_target}"
)
check_kl_div_kernel(B, T, V, reduction, log_target, eps)
print("✓ KL Div passed")


# %%
if __name__ == "__main__":
main()
86 changes: 86 additions & 0 deletions test/test_examples.expected
@@ -1473,6 +1473,92 @@ def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None
final_loss = torch.sum(loss)
return (final_loss, dX)

--- assertExpectedJournal(TestExamples.test_kl_div)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kl_div_forward(y_pred, y_true, kl_loss, loss, kl_loss_stride_0, kl_loss_stride_1, loss_stride_0, y_pred_stride_0, y_pred_stride_1, y_true_stride_0, y_true_stride_1, BT, V, log_target, eps, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_1 = pid_0 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
mask_1 = indices_1 < BT
loss_sum = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
for offset_0 in tl.range(0, V.to(tl.int32), _BLOCK_SIZE_0):
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_0 < V
loss_sum_copy = loss_sum
loss_sum_copy_0 = loss_sum_copy
y_pred_val = tl.load(y_pred + (indices_1[:, None] * y_pred_stride_0 + indices_0[None, :] * y_pred_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
y_true_val = tl.load(y_true + (indices_1[:, None] * y_true_stride_0 + indices_0[None, :] * y_true_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
if log_target:
y_true_val_copy = y_true_val
y_pred_val_copy = y_pred_val
y_true_val_copy_0 = y_true_val_copy
y_pred_val_copy_0 = y_pred_val_copy
v_0 = libdevice.exp(y_true_val_copy_0)
v_1 = y_true_val_copy_0 - y_pred_val_copy_0
v_2 = v_0 * v_1
tl.store(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), v_2, mask_1[:, None] & mask_0[None, :])
_not = not log_target
if _not:
y_true_val_copy_1 = y_true_val
y_pred_val_copy_1 = y_pred_val
y_true_val_copy_1_0 = y_true_val_copy_1
y_pred_val_copy_1_0 = y_pred_val_copy_1
v_3 = triton_helpers.maximum(y_true_val_copy_1_0, eps)
v_4 = tl_math.log(v_3)
v_5 = v_4 - y_pred_val_copy_1_0
v_6 = y_true_val_copy_1_0 * v_5
tl.store(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), v_6, mask_1[:, None] & mask_0[None, :])
load_2 = tl.load(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
loss_sum = loss_sum_copy_0 + load_2
sum_1 = tl.cast(tl.sum(loss_sum, 1), tl.float32)
tl.store(loss + indices_1 * loss_stride_0, sum_1, mask_1)

def kl_div_forward(y_pred: Tensor, y_true: Tensor, log_target: bool=False, reduction: str='batchmean', eps: float=1e-10, *, _launcher=_default_launcher):
"""
Compute KL Divergence loss.

Args:
y_pred: Input predictions in log-space, shape (BT, V)
y_true: Target values (probabilities or log-probabilities), shape (BT, V)
log_target: If True, y_true is in log-space; if False, y_true is probabilities
reduction: Reduction mode ('none', 'sum', 'mean', 'batchmean')
eps: Small value to avoid numerical issues

Returns:
loss: KL divergence loss
"""
BT, V = y_pred.shape
assert y_true.shape == y_pred.shape, f'Shape mismatch: {y_true.shape} != {y_pred.shape}'
if reduction == 'none':
loss = torch.zeros_like(y_pred)
else:
loss = torch.zeros((BT,), dtype=torch.float32, device=y_pred.device)
kl_loss = torch.zeros_like(y_pred)
BT_SIZE = helion.cdiv(BT, BT)
_BLOCK_SIZE_1 = BT_SIZE
_BLOCK_SIZE_0 = 4096
_launcher(_helion_kl_div_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), y_pred, y_true, kl_loss, loss, kl_loss.stride(0), kl_loss.stride(1), loss.stride(0), y_pred.stride(0), y_pred.stride(1), y_true.stride(0), y_true.stride(1), BT, V, log_target, eps, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
if reduction == 'batchmean':
final_loss = torch.sum(loss) / BT
elif reduction == 'sum':
final_loss = torch.sum(loss, dim=0)
elif reduction == 'mean':
final_loss = torch.sum(loss) / (BT * V)
else:
final_loss = loss
return final_loss

--- assertExpectedJournal(TestExamples.test_layernorm_bwd_dwdb)
from __future__ import annotations
