|
| 1 | +""" |
| 2 | +Helion JSD (Jensen-Shannon Divergence) Example |
| 3 | +============================================== |
| 4 | +This example demonstrates a Helion kernel implementation of Jensen-Shannon Divergence. |
| 5 | +JSD is commonly used in knowledge distillation for language models, where: |
| 6 | +
|
| 7 | +JSD(beta)(P || Q) = beta * KL(P || M) + (1-beta) * KL(Q || M) |
| 8 | +where M = beta * P + (1-beta) * Q is the mixture distribution |
| 9 | +
|
| 10 | +The generalized JSD reduces to: |
| 11 | +- Forward KL when beta = 0: KL(P || Q) |
| 12 | +- Reverse KL when beta = 1: KL(Q || P) |
| 13 | +- Symmetric JSD when beta = 0.5 |
| 14 | +
|
| 15 | +Based on liger_kernel's JSD implementation used for knowledge distillation in language models. |
| 16 | +""" |
| 17 | + |
| 18 | +# %% |
| 19 | +# Imports |
| 20 | +# ------- |
| 21 | +from __future__ import annotations |
| 22 | + |
| 23 | +from typing import TYPE_CHECKING |
| 24 | + |
| 25 | +import torch |
| 26 | +from torch import Tensor |
| 27 | +import torch.nn as nn |
| 28 | + |
| 29 | +import helion |
| 30 | +from helion._testing import run_example |
| 31 | +import helion.language as hl |
| 32 | + |
| 33 | +if TYPE_CHECKING: |
| 34 | + from collections.abc import Callable |
| 35 | + |
| 36 | + |
| 37 | +# %% |
| 38 | +# JSD Kernel |
| 39 | +# ---------- |
| 40 | +@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper]) |
| 41 | +def jsd_forward( |
| 42 | + _input: Tensor, # student predictions (input) in log-space |
| 43 | + target: Tensor, # teacher targets in log-space |
| 44 | + shift_labels: Tensor | None = None, |
| 45 | + beta: float = 0.5, |
| 46 | + ignore_index: int = -100, |
| 47 | +) -> tuple[Tensor, Tensor]: |
| 48 | + """ |
| 49 | + Compute Jensen-Shannon Divergence loss. |
| 50 | +
|
| 51 | + Args: |
| 52 | + _input: Student predictions in log-space, shape (BT, V) |
| 53 | + target: Teacher targets in log-space, shape (BT, V) |
| 54 | + shift_labels: Optional labels for masking, shape (BT,) |
| 55 | + beta: Coefficient for generalized JSD in [0, 1] |
| 56 | + ignore_index: Index to ignore in labels |
| 57 | +
|
| 58 | + Returns: |
| 59 | + loss: Scalar JSD loss |
| 60 | + dX: Gradient of loss wrt input |
| 61 | + """ |
| 62 | + BT, V = _input.shape |
| 63 | + assert target.shape == _input.shape, ( |
| 64 | + f"Shape mismatch: {target.shape} != {_input.shape}" |
| 65 | + ) |
| 66 | + n_rows = BT |
| 67 | + |
| 68 | + # Create output tensor for accumulating loss |
| 69 | + loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device) |
| 70 | + dX = torch.empty_like(_input) |
| 71 | + |
| 72 | + # Count non-ignored elements |
| 73 | + n_non_ignore = float(BT) |
| 74 | + if shift_labels is not None: |
| 75 | + n_non_ignore = float((shift_labels != ignore_index).sum().item()) |
| 76 | + if n_non_ignore == 0: |
| 77 | + return torch.zeros( |
| 78 | + [], dtype=_input.dtype, device=_input.device |
| 79 | + ), torch.zeros_like(_input) |
| 80 | + |
| 81 | + # Process each sequence position |
| 82 | + BT_SIZE = helion.cdiv(BT, n_rows) # The liger kernel uses 1 |
| 83 | + for tile_bt in hl.tile(BT, block_size=BT_SIZE): |
| 84 | + # Check for label masking |
| 85 | + if shift_labels is not None: |
| 86 | + if shift_labels[tile_bt] == ignore_index: |
| 87 | + for tile_X in hl.tile(V): |
| 88 | + dX[tile_bt, tile_X] = 0.0 |
| 89 | + continue |
| 90 | + |
| 91 | + for tile_v in hl.tile(V): |
| 92 | + # Load log probabilities and convert to float32 |
| 93 | + X = _input[tile_bt, tile_v] |
| 94 | + Y = target[tile_bt, tile_v] |
| 95 | + X_max = torch.amax(X, dim=0) |
| 96 | + Y_max = torch.amax(Y, dim=0) |
| 97 | + |
| 98 | + if beta == 0.0: # Forward KL: KL(P || Q) |
| 99 | + Y_shift = Y - Y_max |
| 100 | + Y_prob = torch.exp(Y_shift) * torch.exp( |
| 101 | + Y_max |
| 102 | + ) # Compensate for the shift |
| 103 | + loss[tile_bt, tile_v] = Y_prob * (Y - X) |
| 104 | + dX[tile_bt, tile_v] = -Y_prob |
| 105 | + elif beta == 1.0: # Reverse KL: KL(Q || P) |
| 106 | + X_shift = X - X_max |
| 107 | + X_prob = torch.exp(X_shift) * torch.exp( |
| 108 | + X_max |
| 109 | + ) # Compensate for the shift |
| 110 | + loss[tile_bt, tile_v] = X_prob * (X - Y) |
| 111 | + dX[tile_bt, tile_v] = loss[tile_bt, tile_v] + X_prob |
| 112 | + else: # General JSD: beta*KL(P||M) + (1-beta)*KL(Q||M) |
| 113 | + max_val = torch.maximum(X_max, Y_max) |
| 114 | + X_shifted = X - max_val |
| 115 | + Y_shifted = Y - max_val |
| 116 | + |
| 117 | + exp_max = torch.exp(max_val) |
| 118 | + |
| 119 | + Q = torch.exp(X_shifted) * exp_max # = exp(X) |
| 120 | + P = torch.exp(Y_shifted) * exp_max # = exp(Y) |
| 121 | + |
| 122 | + beta_P = beta * P |
| 123 | + one_minus_beta_Q = (1 - beta) * Q |
| 124 | + M = beta_P + one_minus_beta_Q |
| 125 | + log_M = torch.log( |
| 126 | + M |
| 127 | + ) # No need to compensate as M is already in original scale |
| 128 | + |
| 129 | + loss[tile_bt, tile_v] = beta_P * Y + one_minus_beta_Q * X - M * log_M |
| 130 | + dX[tile_bt, tile_v] = one_minus_beta_Q * (X - log_M) |
| 131 | + |
| 132 | + # Accumulate over vocabulary dimension |
| 133 | + scale = 1.0 / n_non_ignore |
| 134 | + loss[tile_bt, tile_v] = loss[tile_bt, tile_v] * scale |
| 135 | + dX[tile_bt, tile_v] = dX[tile_bt, tile_v] * scale |
| 136 | + |
| 137 | + # Normalize by number of non-ignored elements, run it on host to match liger_kernel |
| 138 | + final_loss = torch.sum( |
| 139 | + loss |
| 140 | + ) # This line raises a warning: helion.exc.TensorOperationInWrapper |
| 141 | + |
| 142 | + return final_loss, dX |
| 143 | + |
| 144 | + |
| 145 | +# %% |
| 146 | +# JSD Loss Module (matches liger_kernel structure) |
| 147 | +# ------------------------------------------------ |
| 148 | +class HelionJSD(nn.Module): |
| 149 | + """ |
| 150 | + Helion implementation of Jensen-Shannon Divergence matching liger_kernel.LigerJSD structure. |
| 151 | +
|
| 152 | + JSD(beta)(P || Q) = beta * KL(P || M) + (1-beta) * KL(Q || M) |
| 153 | + where M = beta * P + (1-beta) * Q |
| 154 | +
|
| 155 | + Args: |
| 156 | + beta: Coefficient beta ∈ [0,1]. When beta=0: forward KL, beta=1: reverse KL, beta=0.5: symmetric JSD |
| 157 | + ignore_index: Index to ignore in labels for masking |
| 158 | + dtype: Data type for loss computation |
| 159 | + """ |
| 160 | + |
| 161 | + def __init__( |
| 162 | + self, |
| 163 | + beta: float = 0.5, |
| 164 | + ignore_index: int = -100, |
| 165 | + dtype: torch.dtype = torch.float, |
| 166 | + ) -> None: |
| 167 | + super().__init__() |
| 168 | + self.beta = beta |
| 169 | + self.ignore_index = ignore_index |
| 170 | + self.dtype = dtype |
| 171 | + |
| 172 | + def forward( |
| 173 | + self, |
| 174 | + _input: Tensor, # student predictions in log-space |
| 175 | + target: Tensor, # teacher targets in log-space |
| 176 | + shift_labels: Tensor | None = None, |
| 177 | + ) -> Tensor: |
| 178 | + """ |
| 179 | + Forward pass computing JSD loss. |
| 180 | +
|
| 181 | + Args: |
| 182 | + _input: Student predictions in log-space, shape (BT, V) |
| 183 | + target: Teacher targets in log-space, shape (BT, V) |
| 184 | + shift_labels: Optional labels for masking, shape (BT,) |
| 185 | + Returns: |
| 186 | + Scalar JSD loss |
| 187 | + """ |
| 188 | + if shift_labels is not None: |
| 189 | + assert shift_labels.shape == (_input.shape[0],), ( |
| 190 | + f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" |
| 191 | + ) |
| 192 | + shift_labels = shift_labels.contiguous() |
| 193 | + loss, dX = jsd_forward( |
| 194 | + _input, target, shift_labels, self.beta, self.ignore_index |
| 195 | + ) |
| 196 | + return loss.to(self.dtype) |
| 197 | + |
| 198 | + |
| 199 | +class TorchJSDBaseline(nn.Module): |
| 200 | + """PyTorch baseline JSD implementation matching tritonbench.""" |
| 201 | + |
| 202 | + def __init__( |
| 203 | + self, |
| 204 | + beta: float = 0.5, |
| 205 | + ignore_index: int = -100, |
| 206 | + dtype: torch.dtype = torch.float, |
| 207 | + ) -> None: |
| 208 | + super().__init__() |
| 209 | + self.kl = nn.KLDivLoss(reduction="none", log_target=True) |
| 210 | + self.beta = beta |
| 211 | + self.ignore_index = ignore_index |
| 212 | + self.dtype = dtype |
| 213 | + |
| 214 | + def forward( |
| 215 | + self, log_q: Tensor, log_p: Tensor, label: Tensor | None = None |
| 216 | + ) -> Tensor: |
| 217 | + # Convert to float for computation |
| 218 | + log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) |
| 219 | + log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) |
| 220 | + |
| 221 | + # Mixture distribution |
| 222 | + m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) |
| 223 | + |
| 224 | + # JSD loss |
| 225 | + loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + ( |
| 226 | + 1 - self.beta |
| 227 | + ) * self.kl(torch.log(m), log_q).sum(dim=-1) |
| 228 | + |
| 229 | + if label is not None: |
| 230 | + loss = torch.where(label != self.ignore_index, loss, 0.0) |
| 231 | + n_non_ignore = (label != self.ignore_index).sum().item() |
| 232 | + if n_non_ignore == 0: |
| 233 | + loss = torch.tensor(0.0, device=log_q.device, dtype=torch.float) |
| 234 | + else: |
| 235 | + loss = (loss / n_non_ignore).sum() |
| 236 | + else: |
| 237 | + loss = (loss / log_q.shape[0]).sum() |
| 238 | + |
| 239 | + return loss.to(self.dtype) |
| 240 | + |
| 241 | + |
| 242 | +# %% |
| 243 | +# Verification Function |
| 244 | +# --------------------- |
| 245 | +def check_jsd_kernel( |
| 246 | + B: int, |
| 247 | + T: int, |
| 248 | + V: int, |
| 249 | + beta: float = 0.5, |
| 250 | + ignore_index: int = -100, |
| 251 | + use_labels: bool = False, |
| 252 | +) -> None: |
| 253 | + """ |
| 254 | + Verify the JSD kernel implementation against PyTorch's baseline. |
| 255 | +
|
| 256 | + Args: |
| 257 | + B: Batch size (B) |
| 258 | + T: Sequence length (T) |
| 259 | + V: Vocabulary size (V) |
| 260 | + beta: JSD coefficient |
| 261 | + ignore_index: Index to ignore in labels |
| 262 | + use_labels: Whether to test with label masking |
| 263 | + """ |
| 264 | + # Create test tensors |
| 265 | + log_q = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(dim=-1) |
| 266 | + log_p = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1) |
| 267 | + |
| 268 | + shift_labels = None |
| 269 | + if use_labels: |
| 270 | + shift_labels = torch.randint(0, V, (B,), device="cuda") |
| 271 | + # Randomly set some to ignore_index |
| 272 | + shift_labels[torch.rand(B, device="cuda") < 0.1] = -100 |
| 273 | + |
| 274 | + # Test forward pass only (no gradients for now) |
| 275 | + helion_jsd = HelionJSD(beta=beta, ignore_index=ignore_index) |
| 276 | + torch_jsd = TorchJSDBaseline(beta=beta, ignore_index=ignore_index) |
| 277 | + |
| 278 | + def helion_wrapper( |
| 279 | + log_q: Tensor, log_p: Tensor, shift_labels: Tensor | None = None |
| 280 | + ) -> Tensor: |
| 281 | + return helion_jsd(log_q, log_p, shift_labels) |
| 282 | + |
| 283 | + def baseline_wrapper( |
| 284 | + log_q: Tensor, log_p: Tensor, shift_labels: Tensor | None = None |
| 285 | + ) -> Tensor: |
| 286 | + return torch_jsd(log_q, log_p, shift_labels) |
| 287 | + |
| 288 | + run_example(helion_wrapper, baseline_wrapper, (log_q, log_p, shift_labels)) |
| 289 | + |
| 290 | + |
| 291 | +# %% |
| 292 | +# Tritonbench Integration |
| 293 | +# ----------------------- |
| 294 | +def jsd_tritonbench(tb_op: object, log_q: Tensor, log_p: Tensor) -> Callable: |
| 295 | + """ |
| 296 | + Wrapper for tritonbench that matches its interface. |
| 297 | +
|
| 298 | + Args: |
| 299 | + log_q: Student predictions in log-space |
| 300 | + log_p: Teacher targets in log-space |
| 301 | +
|
| 302 | + Returns: |
| 303 | + Callable: A callable that runs the JSD kernel |
| 304 | + """ |
| 305 | + |
| 306 | + baseline_model = tb_op.baseline_op # pyright: ignore[reportAttributeAccessIssue] |
| 307 | + |
| 308 | + helion_jsd = HelionJSD( |
| 309 | + beta=baseline_model.beta, |
| 310 | + ignore_index=baseline_model.ignore_index, |
| 311 | + dtype=baseline_model.dtype, |
| 312 | + ) |
| 313 | + |
| 314 | + return lambda: helion_jsd(log_q, log_p) |
| 315 | + |
| 316 | + |
| 317 | +# %% |
| 318 | +# Main Function |
| 319 | +# ------------- |
| 320 | +def main() -> None: |
| 321 | + """ |
| 322 | + Main entry point that runs JSD kernel verification. |
| 323 | + Tests various configurations including different beta values and label masking. |
| 324 | + """ |
| 325 | + print("Testing JSD kernel...") |
| 326 | + B = 4 |
| 327 | + T = 2048 |
| 328 | + beta = 0.5 |
| 329 | + ignore_index = -100 |
| 330 | + use_labels = False |
| 331 | + |
| 332 | + for V in [2**i for i in range(12, 18)]: |
| 333 | + print( |
| 334 | + f"Testing JSD: B={B}, T={T}, V={V}, beta={beta}, ignore_index={ignore_index}, labels={use_labels}" |
| 335 | + ) |
| 336 | + check_jsd_kernel(B, T, V, beta, ignore_index, use_labels) |
| 337 | + print("✓ JSD passed") |
| 338 | + |
| 339 | + |
| 340 | +# %% |
| 341 | +if __name__ == "__main__": |
| 342 | + main() |
0 commit comments