Commit 67c596e

100%|███████████████████████████████████████████████████████████████████████████████████████| 22/22 [01:20<00:00, 3.66s/it]
  num_rows    num_cols  high_precision_dtype    low_precision_dtype      cuda_time    pytorch_time    compiled_pytorch_time
----------  ----------  ----------------------  ---------------------  -----------  --------------  -----------------------
       512         512  torch.bfloat16          torch.float8_e4m3fn        5.16311         38.2956                  17.9364
       512         512  torch.bfloat16          torch.float8_e5m2          5.04093         37.7555                  17.99
      1024        1024  torch.bfloat16          torch.float8_e4m3fn        5.15543         38.1199                  58.628
      1024        1024  torch.bfloat16          torch.float8_e5m2          5.13135         38.415                   58.9279
      2048        2048  torch.bfloat16          torch.float8_e4m3fn        6.82004         69.8858                  20.4935
      2048        2048  torch.bfloat16          torch.float8_e5m2          6.82519         70.1765                  20.4055
      1024        8192  torch.bfloat16          torch.float8_e4m3fn       11.5463         153.481                   20.5978
      1024        8192  torch.bfloat16          torch.float8_e5m2         11.5646         153.569                   20.5893
      8192        1280  torch.bfloat16          torch.float8_e4m3fn       13.8572         209.535                   20.4999
      8192        1280  torch.bfloat16          torch.float8_e5m2         13.8607         209.533                   20.3462
      8192        7168  torch.bfloat16          torch.float8_e4m3fn       81.5091        1177.47                    85.1974
      8192        7168  torch.bfloat16          torch.float8_e5m2         81.7804        1177.95                    85.2696
      3584        8192  torch.bfloat16          torch.float8_e4m3fn       43.1506         605.872                   44.7317
      3584        8192  torch.bfloat16          torch.float8_e5m2         43.1619         606.259                   44.725
      2048      109760  torch.bfloat16          torch.float8_e4m3fn      299.782         4407.9                    319.255
      2048      109760  torch.bfloat16          torch.float8_e5m2        299.757         4408.04                   320.204
         1        3232  torch.bfloat16          torch.float8_e4m3fn        5.01986         38.1827                  58.037
         1        3232  torch.bfloat16          torch.float8_e5m2          5.05898         38.0519                  58.6179
      2048           1  torch.bfloat16          torch.float8_e4m3fn        5.00403         37.8273                  57.656
      2048           1  torch.bfloat16          torch.float8_e5m2          5.05094         38.2455                  57.776
     14144        2048  torch.bfloat16          torch.float8_e4m3fn       42.632          598.726                   44.2741
     14144        2048  torch.bfloat16          torch.float8_e5m2         42.6492         598.401                   44.2601
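The timing columns above are produced by the repository's benchmark script (benchmarks/benchmark_saturated_casting.py). For orientation only, the sketch below shows one common way to take a per-kernel GPU timing with CUDA events; dummy_kernel is a placeholder of mine and the measurement style is an assumption, not necessarily how the benchmark script measures.

    // Illustrative only: timing a kernel launch with CUDA events.
    // dummy_kernel is a stand-in, not part of this commit.
    #include <cuda_runtime.h>
    #include <cstdio>

    __global__ void dummy_kernel(float *x, int n) {
      const int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) x[i] *= 2.0f;
    }

    int main() {
      const int n = 1 << 20;
      float *d_x;
      cudaMalloc(&d_x, n * sizeof(float));

      cudaEvent_t start, stop;
      cudaEventCreate(&start);
      cudaEventCreate(&stop);

      cudaEventRecord(start);
      dummy_kernel<<<(n + 255) / 256, 256>>>(d_x, n);
      cudaEventRecord(stop);
      cudaEventSynchronize(stop);

      float ms = 0.0f;
      cudaEventElapsedTime(&ms, start, stop);  // elapsed GPU time in milliseconds
      printf("kernel time: %.3f ms\n", ms);

      cudaEventDestroy(start);
      cudaEventDestroy(stop);
      cudaFree(d_x);
      return 0;
    }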
1 parent e43ede0 commit 67c596e

File tree

2 files changed: +74 -13 lines changed


benchmarks/benchmark_saturated_casting.py

Lines changed: 4 additions & 0 deletions

@@ -73,6 +73,10 @@ def get_configs() -> List[ExperimentConfig]:
         (8192, 1280),
         (8192, 7168),
         (3584, 8192),
+        (2048, 109760),
+        (1, 3232),
+        (2048, 1),
+        (14144, 2048),
     ]
     for (rows, cols), high_precision_dtype, low_precision_dtype in itertools.product(
         num_rows_cols, high_precision_dtypes, low_precision_dtypes

src/saturated_cast.cu

Lines changed: 70 additions & 13 deletions
@@ -28,7 +28,41 @@ __global__ void saturated_cast_kernel_single(
   }
 }

-template<int coarse_factor>
+template <int coarse_factor>
+__global__ void saturated_cast_kernel_double_coalesced_flat(
+    nv_bfloat162 const *__restrict input,
+    __nv_fp8x2_storage_t *__restrict output, const int numels,
+    __nv_fp8_interpretation_t out_dtype, nv_bfloat16 const *scaler) {
+  const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * coarse_factor;
+  const int stride = 1;
+  const nv_bfloat162 scale_2 = {(*scaler), (*scaler)};
+
+  nv_bfloat162 scaled_inputs[coarse_factor];
+#pragma unroll
+  for (int i{0}; i < coarse_factor; ++i) {
+    const int temp_idx = idx + i;
+    if (temp_idx < numels) {
+      scaled_inputs[i] = input[temp_idx * stride];
+    }
+  }
+#pragma unroll
+  for (int i{0}; i < coarse_factor; ++i) {
+    const int temp_idx = idx + i;
+    if (temp_idx < numels) {
+      scaled_inputs[i] = __hmul2(scaled_inputs[i], scale_2);
+    }
+  }
+#pragma unroll
+  for (int i{0}; i < coarse_factor; ++i) {
+    const int temp_idx = idx + i;
+    if (temp_idx < numels) {
+      output[temp_idx * stride] = __nv_cvt_bfloat16raw2_to_fp8x2(
+          scaled_inputs[i], __nv_saturation_t::__NV_SATFINITE, out_dtype);
+    }
+  }
+}
+
+template <int coarse_factor>
 __global__ void saturated_cast_kernel_double_coalesced(
     nv_bfloat162 const *__restrict input,
     __nv_fp8x2_storage_t *__restrict output, int n_rows, int n_cols,
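In the new flat kernel each thread owns a contiguous run of coarse_factor packed bf16 pairs, and the load, scale, and store loops are kept separate so each can be unrolled on its own. A minimal host-side sketch of that index math follows; the coarse_factor of 4 and the 256-thread blocks are taken from the dispatcher change further down, and the snippet is only an illustration of the ownership pattern.

    // Illustration of the flat kernel's per-thread ownership (assumes
    // coarse_factor = 4 and blockDim.x = 256, as used in dispatch_best_kernel).
    // Thread t of block b starts at packed index (b * 256 + t) * 4 and covers 4
    // consecutive nv_bfloat162 pairs (8 bf16 values); the temp_idx < numels
    // guard in the kernel handles the ragged tail.
    #include <cstdio>
    #include <initializer_list>

    int main() {
      const int coarse_factor = 4;
      const int block_dim = 256;
      for (int b : {0, 1}) {
        for (int t : {0, 1, 255}) {
          const int idx = (b * block_dim + t) * coarse_factor;
          printf("block %d, thread %3d -> packed pairs [%d, %d)\n",
                 b, t, idx, idx + coarse_factor);
        }
      }
      return 0;
    }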
@@ -59,9 +93,8 @@ __global__ void saturated_cast_kernel_double_coalesced(
       const int temp_col = col + i;
       if (row < n_rows && temp_col < n_cols) {
         output[row * row_stride + temp_col * col_stride] =
-            __nv_cvt_bfloat16raw2_to_fp8x2(scaled_inputs[i],
-                                           __nv_saturation_t::__NV_SATFINITE,
-                                           out_dtype);
+            __nv_cvt_bfloat16raw2_to_fp8x2(
+                scaled_inputs[i], __nv_saturation_t::__NV_SATFINITE, out_dtype);
       }
     }
   }
@@ -84,8 +117,26 @@ void dispatch_best_kernel(const Tensor &input, const Tensor &output,
   const int n_cols = input.size(1);
   const int block_size_x = 32;
   const int block_size_y = 32;
-  if (n_cols % 2 == 0) {
-    // We cast to a 2x8 type, so we need to divide the number of columns by 2
+  const auto numel = input.numel();
+  int kernel_choice = 0;
+  if (numel % 2 == 0 && !transpose) {
+    kernel_choice = 2;
+  } else if (n_cols % 2 == 0) {
+    kernel_choice = 1;
+  }
+  switch (kernel_choice) {
+  case 0: {
+    const dim3 block(block_size_x, block_size_y);
+    const dim3 grid(ceil_div(n_cols, block_size_x),
+                    ceil_div(n_rows, block_size_y));
+    saturated_cast_kernel_single<<<grid, block>>>(
+        static_cast<nv_bfloat16 *>(input.data_ptr()),
+        static_cast<__nv_fp8_storage_t *>(output.data_ptr()), n_rows, n_cols,
+        out_dtype, static_cast<nv_bfloat16 *>(scale.data_ptr()));
+    break;
+  }
+  case 1: {
+    // We cast to a 16x2 type, so we need to divide the number of columns by 2
     const auto packed_col_size = n_cols / 2;
     // Found 4 to be the best factor for the coalesced kernel
     const int coarse_factor = 4;
@@ -97,14 +148,20 @@ void dispatch_best_kernel(const Tensor &input, const Tensor &output,
         static_cast<__nv_fp8x2_storage_t *>(output.data_ptr()), n_rows,
         packed_col_size, out_dtype,
         static_cast<nv_bfloat16 *>(scale.data_ptr()));
-  } else {
-    const dim3 block(block_size_x, block_size_y);
-    const dim3 grid(ceil_div(n_cols, block_size_x),
-                    ceil_div(n_rows, block_size_y));
-    saturated_cast_kernel_single<<<grid, block>>>(
-        static_cast<nv_bfloat16 *>(input.data_ptr()),
-        static_cast<__nv_fp8_storage_t *>(output.data_ptr()), n_rows, n_cols,
+    break;
+  }
+  case 2: {
+    const int coarse_factor = 4;
+    const dim3 block(256);
+    const int packed_numel = numel / 2;
+    // We divide numel by 2 because we are casting to a 16x2 type
+    const dim3 grid(ceil_div(packed_numel, block.x * coarse_factor));
+    saturated_cast_kernel_double_coalesced_flat<coarse_factor><<<grid, block>>>(
+        static_cast<nv_bfloat162 *>(input.data_ptr()),
+        static_cast<__nv_fp8x2_storage_t *>(output.data_ptr()), packed_numel,
         out_dtype, static_cast<nv_bfloat16 *>(scale.data_ptr()));
+    break;
+  }
   }
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
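Putting the dispatcher change together with the shapes added to the benchmark: on the non-transposed path, any input with an even element count now takes the new flat kernel (case 2), with the 2D coalesced and scalar kernels kept as fallbacks. Below is a self-contained host-only sketch of that decision and the resulting launch geometry; the ceil_div here is an assumed stand-in for the helper used in saturated_cast.cu, and the transpose path is left out.

    // Sketch of dispatch_best_kernel's choice and grid size for the newly
    // benchmarked shapes. ceil_div is assumed equivalent to the helper in
    // saturated_cast.cu; transpose handling is omitted.
    #include <cstdio>

    static long long ceil_div(long long a, long long b) { return (a + b - 1) / b; }

    int main() {
      const long long shapes[][2] = {
          {2048, 109760}, {1, 3232}, {2048, 1}, {14144, 2048}};
      const bool transpose = false;
      for (const auto &s : shapes) {
        const long long n_rows = s[0], n_cols = s[1], numel = n_rows * n_cols;
        int kernel_choice = 0;                       // scalar fallback
        if (numel % 2 == 0 && !transpose) {
          kernel_choice = 2;                         // flat coalesced kernel
        } else if (n_cols % 2 == 0) {
          kernel_choice = 1;                         // 2D coalesced kernel
        }
        if (kernel_choice == 2) {
          const long long coarse_factor = 4, block = 256;
          const long long packed_numel = numel / 2;  // bf16 values viewed as pairs
          printf("%lld x %lld -> flat kernel, %lld blocks of %lld threads\n",
                 n_rows, n_cols, ceil_div(packed_numel, block * coarse_factor), block);
        } else {
          printf("%lld x %lld -> kernel_choice %d\n", n_rows, n_cols, kernel_choice);
        }
      }
      return 0;
    }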
