"""
Helion KL Divergence Example
============================
This example demonstrates a Helion kernel implementation of the Kullback-Leibler
(KL) divergence. KL divergence is commonly used in deep learning for comparing
probability distributions:

KL(P || Q) = sum_i P(i) * log(P(i) / Q(i))

When the input is in log-space (as is common with log-softmax outputs):
KL(P || Q) = sum_i P(i) * (log(P(i)) - log(Q(i)))

The loss supports different reduction modes:
- 'none': No reduction, returns per-element losses
- 'sum': Sum of all losses
- 'mean': Average over all elements
- 'batchmean': Average over the batch dimension

Based on liger_kernel's KL divergence implementation used in language models.
A plain PyTorch sketch of the batchmean formula appears after the imports below.
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch import Tensor
import torch.nn as nn

import helion
from helion._testing import run_example
import helion.language as hl

if TYPE_CHECKING:
    from collections.abc import Callable


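# %%
# Reference Formula (Eager PyTorch)
# ---------------------------------
# A minimal eager-mode sketch of the formula from the module docstring, for
# intuition only. The helper name and its eps clamping are assumptions of this
# sketch; it is not used by the Helion kernel below.
def _kl_div_reference(log_q: Tensor, p: Tensor, eps: float = 1e-10) -> Tensor:
    """KL(P || Q) with 'batchmean' reduction, where log_q holds log-space predictions."""
    # Per-element terms: P * (log(P) - log(Q)); the clamp avoids log(0).
    per_element = p * (torch.log(torch.clamp(p, min=eps)) - log_q)
    # Sum over the vocabulary dimension, then average over the batch dimension.
    return per_element.sum(dim=-1).mean()

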
# %%
# KL Divergence Kernel
# --------------------
@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def kl_div_forward(
    y_pred: Tensor,  # input predictions in log-space, shape (BT, V)
    y_true: Tensor,  # target values, shape (BT, V)
    log_target: bool = False,
    reduction: str = "batchmean",
    eps: float = 1e-10,
) -> Tensor:
    """
    Compute KL Divergence loss.

    Args:
        y_pred: Input predictions in log-space, shape (BT, V)
        y_true: Target values (probabilities or log-probabilities), shape (BT, V)
        log_target: If True, y_true is in log-space; if False, y_true is probabilities
        reduction: Reduction mode ('none', 'sum', 'mean', 'batchmean')
        eps: Small value to avoid numerical issues

    Returns:
        loss: KL divergence loss
    """
    BT, V = y_pred.shape
    assert y_true.shape == y_pred.shape, (
        f"Shape mismatch: {y_true.shape} != {y_pred.shape}"
    )

    # Initialize loss accumulator
    if reduction == "none":
        loss = torch.zeros_like(y_pred)
    else:
        loss = torch.zeros((BT,), dtype=torch.float32, device=y_pred.device)

    kl_loss = torch.zeros_like(y_pred)

    # Register the block size so block_size_n is known outside of the reduction loop.
    block_size_n = hl.register_block_size(V)

    BT_SIZE = helion.cdiv(BT, BT)  # == 1, so each tile covers a single row of the batch
    for tile_bt in hl.tile(BT, block_size=BT_SIZE):
        loss_sum = hl.zeros([tile_bt, block_size_n], dtype=torch.float32)

        for tile_v in hl.tile(V, block_size=block_size_n):
            y_pred_val = y_pred[tile_bt, tile_v]
            y_true_val = y_true[tile_bt, tile_v]

            if log_target:
                # KL(P || Q) = exp(y_true) * (y_true - y_pred) when both in log-space
                prob_true = torch.exp(y_true_val)
                kl_loss[tile_bt, tile_v] = prob_true * (y_true_val - y_pred_val)
            else:
                # KL(P || Q) = y_true * (log(y_true) - y_pred) when y_pred in log-space
                log_true = torch.log(torch.clamp(y_true_val, min=eps))
                kl_loss[tile_bt, tile_v] = y_true_val * (log_true - y_pred_val)

            if reduction == "none":
                loss[tile_bt, tile_v] = kl_loss[tile_bt, tile_v]
            else:
                # Sum over vocabulary dimension
                loss_sum += kl_loss[tile_bt, tile_v]

        if reduction != "none":
            loss[tile_bt] = loss_sum.sum(dim=-1)

    # Apply final reduction
    if reduction == "batchmean":
        final_loss = torch.sum(loss) / BT
    elif reduction == "sum":
        final_loss = torch.sum(loss, dim=0)
    elif reduction == "mean":
        final_loss = torch.sum(loss) / (BT * V)
    else:  # reduction == "none"
        final_loss = loss

    return final_loss


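# %%
# Reduction Semantics
# -------------------
# Summary of the final reduction block above, with per-element terms
# kl[i, j] = P[i, j] * (log P[i, j] - log Q[i, j]):
#   'none'      -> kl                     (shape (BT, V))
#   'sum'       -> kl.sum()               (scalar)
#   'mean'      -> kl.sum() / (BT * V)    (scalar)
#   'batchmean' -> kl.sum() / BT          (scalar)

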
# %%
# KL Divergence Loss Module
# -------------------------
class HelionKLDivLoss(nn.Module):
    """
    Helion implementation of KL Divergence Loss matching PyTorch's KLDivLoss.

    KL(P || Q) computes the divergence between target distribution P and input Q.
    A short usage sketch follows the class definition below.

    Args:
        reduction: Reduction mode ('none', 'sum', 'mean', 'batchmean')
        log_target: If True, target is in log-space; if False, target is probabilities
        eps: Small value for numerical stability
    """

    def __init__(
        self,
        reduction: str = "batchmean",
        log_target: bool = False,
        eps: float = 1e-10,
    ) -> None:
        super().__init__()
        self.reduction = reduction
        self.log_target = log_target
        self.eps = eps

    def forward(self, input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
        """
        Forward pass computing KL divergence loss.

        Args:
            input_tensor: Input predictions in log-space, shape (BT, V)
            target_tensor: Target values (probabilities or log-probabilities), shape (BT, V)

        Returns:
            KL divergence loss
        """
        return kl_div_forward(
            input_tensor, target_tensor, self.log_target, self.reduction, self.eps
        )


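# %%
# Usage Sketch
# ------------
# A minimal sketch of calling the module above. The shapes mirror the settings
# in main() and a CUDA device is assumed, so it is kept in comments rather than
# executed at import time:
#
#     kl_loss_fn = HelionKLDivLoss(reduction="batchmean", log_target=False)
#     y_pred = torch.randn(8 * 512, 4096, device="cuda").log_softmax(dim=-1)
#     y_true = torch.randn(8 * 512, 4096, device="cuda").softmax(dim=-1)
#     loss = kl_loss_fn(y_pred, y_true)  # scalar tensor

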
# %%
# Verification Function
# ---------------------
def check_kl_div_kernel(
    B: int,
    T: int,
    V: int,
    reduction: str = "batchmean",
    log_target: bool = False,
    eps: float = 1e-10,
) -> None:
    """
    Verify the KL divergence kernel implementation against PyTorch's baseline.

    Args:
        B: Batch size
        T: Sequence length
        V: Vocabulary size
        reduction: Reduction mode
        log_target: Whether target is in log-space
        eps: Small value for numerical stability
    """
    # Create test tensors following the tritonbench pattern
    input_tensor = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
        dim=-1
    )

    target_tensor = torch.randn(B * T, V, device="cuda").softmax(dim=-1)

    # Test the forward pass against PyTorch's KLDivLoss with the same settings
    helion_kl = HelionKLDivLoss(reduction=reduction, log_target=log_target, eps=eps)
    torch_kl_div = torch.nn.KLDivLoss(reduction=reduction, log_target=log_target).to(
        "cuda"
    )

    def helion_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
        return helion_kl(input_tensor, target_tensor)

    def baseline_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
        return torch_kl_div(input_tensor, target_tensor)

    run_example(helion_wrapper, baseline_wrapper, (input_tensor, target_tensor))


# %%
# Tritonbench Integration
# -----------------------
def kl_div_tritonbench(
    tb_op: object, input_tensor: Tensor, target_tensor: Tensor
) -> Callable:
    """
    Wrapper for tritonbench that matches its interface.

    Args:
        tb_op: Tritonbench operator object
        input_tensor: Input predictions in log-space
        target_tensor: Target values

    Returns:
        Callable: A callable that runs the KL divergence kernel
    """
    helion_kl = HelionKLDivLoss(
        reduction="batchmean",
        log_target=False,  # tritonbench uses probabilities, not log-probabilities
        eps=1e-10,
    )

    return lambda: helion_kl(input_tensor, target_tensor)


# %%
# Main Function
# -------------
def main() -> None:
    """
    Main entry point that runs KL divergence kernel verification.
    Tests various configurations matching tritonbench settings.
    """
    print("Testing KL divergence kernel...")
    B = 8
    T = 512
    reduction = "batchmean"
    log_target = False
    eps = 1e-10

    # Test with vocabulary sizes from tritonbench (2^12 to 2^17)
    for V in [2**i for i in range(12, 18)]:
        print(
            f"Testing KL Div: B={B}, T={T}, V={V}, reduction={reduction}, log_target={log_target}"
        )
        check_kl_div_kernel(B, T, V, reduction, log_target, eps)
        print("✓ KL Div passed")


# %%
if __name__ == "__main__":
    main()