|
| 1 | +""" |
| 2 | +INT4 General Matrix Multiplication (GEMM) with Helion |
| 3 | +===================================================== |
| 4 | +This example demonstrates an INT4 GEMM kernel implemented in Helion. The kernel performs |
| 5 | +matrix multiplication where the second matrix B is packed with two 4-bit values per byte. |
| 6 | +The kernel unpacks the int4 values, converts to bfloat16, and performs matmul with |
| 7 | +the bfloat16 matrix A. |
| 8 | +""" |
| 9 | + |
| 10 | +# %% |
| 11 | +# Imports |
| 12 | +# ------- |
| 13 | +from __future__ import annotations |
| 14 | + |
| 15 | +from typing import Callable |
| 16 | + |
| 17 | +import torch |
| 18 | +from torch import Tensor |
| 19 | + |
| 20 | +import helion |
| 21 | +import helion.language as hl |
| 22 | + |
| 23 | + |
| 24 | +# %% |
| 25 | +# INT4 GEMM Kernel |
| 26 | +# ---------------- |
| 27 | +@helion.kernel( |
| 28 | + use_default_config=True, |
| 29 | + static_shapes=False, # Allow dynamic shapes to handle different input sizes |
| 30 | +) |
| 31 | +def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor: |
| 32 | + """ |
| 33 | + BFloat16 x INT4 General Matrix Multiplication (GEMM). |
| 34 | +
|
| 35 | + This kernel performs matrix multiplication where: |
| 36 | + - A is a bfloat16 matrix of shape [M, K] |
| 37 | + - B is an int8 matrix of shape [K//2, N] containing packed int4 values |
| 38 | + (two 4-bit values packed into each int8) |
| 39 | +
|
| 40 | + Args: |
| 41 | + A (Tensor): Input tensor of shape [M, K] in bfloat16 format. |
| 42 | + B (Tensor): Packed int4 tensor of shape [K//2, N] in int8 format. |
| 43 | +
|
| 44 | + Returns: |
| 45 | + Tensor: Output tensor of shape [M, N] in bfloat16 format. |
| 46 | + """ |
| 47 | + M, K = A.shape |
| 48 | + _, N = B.shape |
| 49 | + |
| 50 | + C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device) |
| 51 | + block_size_k_packed = hl.register_block_size(K // 2) |
| 52 | + |
| 53 | + # Use Helion to tile the computation |
| 54 | + for tile_m, tile_n in hl.tile([M, N]): |
| 55 | + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) |
| 56 | + |
| 57 | + for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed): |
| 58 | + # Load packed int8 data from B |
| 59 | + b_tile = B[tile_k_packed, tile_n] # [BLOCK_SIZE_K//2, BLOCK_SIZE_N] |
| 60 | + |
| 61 | + # Extract low and high 4-bit values with sign extension |
| 62 | + # Low nibble: sign-extend from 4-bit to 8-bit using left shift then arithmetic right shift |
| 63 | + b_lo = ((b_tile << 4) >> 4).to(torch.int8) # Sign-extend low 4 bits |
| 64 | + b_hi = (b_tile >> 4).to(torch.int8) # Sign-extend high 4 bits |
| 65 | + |
| 66 | + # Convert to bfloat16 |
| 67 | + b_lo_bf16 = b_lo.to(torch.bfloat16) # [BLOCK_SIZE_K//2, BLOCK_SIZE_N] |
| 68 | + b_hi_bf16 = b_hi.to(torch.bfloat16) # [BLOCK_SIZE_K//2, BLOCK_SIZE_N] |
| 69 | + |
| 70 | + # Stack and reshape to interleave low and high bits |
| 71 | + # Stack along a new dimension to get [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N] |
| 72 | + b_stacked = torch.stack([b_lo_bf16, b_hi_bf16], dim=1) |
| 73 | + |
| 74 | + # Reshape to interleave: [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N] -> [BLOCK_SIZE_K, BLOCK_SIZE_N] |
| 75 | + # This will place elements in the order: b_lo[0], b_hi[0], b_lo[1], b_hi[1], ... |
| 76 | + b_unpacked = b_stacked.reshape( |
| 77 | + tile_k_packed.block_size * 2, tile_n.block_size |
| 78 | + ) |
| 79 | + |
| 80 | + # Load corresponding tiles from A (need to load twice the packed tile size) |
| 81 | + # We need to map tile_k_packed to the corresponding range in A |
| 82 | + a_tile_begin = tile_k_packed.begin * 2 |
| 83 | + a_tile_len = tile_k_packed.block_size * 2 |
| 84 | + a_tile = A[ |
| 85 | + tile_m, a_tile_begin : (a_tile_begin + a_tile_len) |
| 86 | + ] # [BLOCK_SIZE_M, BLOCK_SIZE_K] |
| 87 | + |
| 88 | + acc = acc + hl.dot(a_tile, b_unpacked) # [BLOCK_SIZE_M, BLOCK_SIZE_N] |
| 89 | + |
| 90 | + C[tile_m, tile_n] = acc.to(torch.bfloat16) |
| 91 | + |
| 92 | + return C |
| 93 | + |
| 94 | + |
| 95 | +# %% |
| 96 | +# TritonBench Wrapper |
| 97 | +# ------------------- |
| 98 | +def int4_gemm_tritonbench(tb_op: object, x: torch.Tensor, w: torch.Tensor) -> Callable: |
| 99 | + """ |
| 100 | + Wrapper for TritonBench compatibility. |
| 101 | +
|
| 102 | + Args: |
| 103 | + tb_op: TritonBench operator instance |
| 104 | + x (torch.Tensor): Left input tensor in bfloat16 format. |
| 105 | + w (torch.Tensor): Right input tensor of shape [K, N] containing int4 values. |
| 106 | + Will be packed to int4 format. |
| 107 | +
|
| 108 | + Returns: |
| 109 | + Callable: A function that performs the int4 gemm. |
| 110 | + """ |
| 111 | + |
| 112 | + def run_kernel() -> torch.Tensor: |
| 113 | + x_2d = x.reshape(-1, x.size(-1)) |
| 114 | + |
| 115 | + # Pack w to int4 format (two 4-bit values per int8 byte) |
| 116 | + w_int8 = w.to(torch.int8) |
| 117 | + w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2) |
| 118 | + w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8) |
| 119 | + |
| 120 | + return matmul_bf16_int4(x_2d, w_packed) |
| 121 | + |
| 122 | + return run_kernel |
| 123 | + |
| 124 | + |
| 125 | +# %% |
| 126 | +# Verification Function |
| 127 | +# --------------------- |
| 128 | +def check(m: int, k: int, n: int) -> None: |
| 129 | + """ |
| 130 | + Test the INT4 GEMM implementation. |
| 131 | +
|
| 132 | + Args: |
| 133 | + m (int): Number of rows in the left input matrix. |
| 134 | + k (int): Shared dimension (must be even). |
| 135 | + n (int): Number of columns in the right input matrix. |
| 136 | + """ |
| 137 | + # Create test matrices |
| 138 | + A = torch.randn(m, k, dtype=torch.bfloat16, device="cuda") |
| 139 | + |
| 140 | + # Create packed int4 matrix B (K//2 x N) |
| 141 | + # Generate random int4 values in range [-8, 7] and pack them |
| 142 | + B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device="cuda") |
| 143 | + |
| 144 | + # Pack using the same format as tritonbench |
| 145 | + B_reshaped = B_unpacked.reshape(k // 2, 2, n).permute(1, 0, 2) |
| 146 | + B_packed = ((B_reshaped[0] & 0xF) | (B_reshaped[1] << 4)).to(torch.int8) |
| 147 | + |
| 148 | + # Convert unpacked values to bfloat16 for reference |
| 149 | + B_unpacked_bf16 = B_unpacked.to(torch.bfloat16) |
| 150 | + |
| 151 | + # Compute reference result |
| 152 | + expected = torch.matmul(A, B_unpacked_bf16) |
| 153 | + |
| 154 | + # Run the kernel |
| 155 | + result = matmul_bf16_int4(A, B_packed) |
| 156 | + |
| 157 | + # Check accuracy with appropriate tolerance |
| 158 | + torch.testing.assert_close(result, expected, rtol=2e-1, atol=1.0) |
| 159 | + print(f"Test passed for shapes: M={m}, K={k}, N={n}") |
| 160 | + |
| 161 | + |
| 162 | +# %% |
| 163 | +# Main Function |
| 164 | +# ------------- |
| 165 | +def main() -> None: |
| 166 | + """ |
| 167 | + Main function to run tests with different matrix sizes. |
| 168 | + """ |
| 169 | + check(256, 512, 256) |
| 170 | + check(512, 512, 512) |
| 171 | + check(1024, 1024, 1024) |
| 172 | + |
| 173 | + |
| 174 | +# %% |
| 175 | +# Run Example |
| 176 | +# ----------- |
| 177 | +if __name__ == "__main__": |
| 178 | + main() |
0 commit comments