@@ -28,7 +28,41 @@ __global__ void saturated_cast_kernel_single(
   }
 }

-template <int coarse_factor>
+template <int coarse_factor>
+__global__ void saturated_cast_kernel_double_coalesced_flat(
+    nv_bfloat162 const *__restrict input,
+    __nv_fp8x2_storage_t *__restrict output, const int numels,
+    __nv_fp8_interpretation_t out_dtype, nv_bfloat16 const *scaler) {
+  const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * coarse_factor;
+  const int stride = 1;
+  const nv_bfloat162 scale_2 = {(*scaler), (*scaler)};
+
+  nv_bfloat162 scaled_inputs[coarse_factor];
+#pragma unroll
+  for (int i{0}; i < coarse_factor; ++i) {
+    const int temp_idx = idx + i;
+    if (temp_idx < numels) {
+      scaled_inputs[i] = input[temp_idx * stride];
+    }
+  }
+#pragma unroll
+  for (int i{0}; i < coarse_factor; ++i) {
+    const int temp_idx = idx + i;
+    if (temp_idx < numels) {
+      scaled_inputs[i] = __hmul2(scaled_inputs[i], scale_2);
+    }
+  }
+#pragma unroll
+  for (int i{0}; i < coarse_factor; ++i) {
+    const int temp_idx = idx + i;
+    if (temp_idx < numels) {
+      output[temp_idx * stride] = __nv_cvt_bfloat16raw2_to_fp8x2(
+          scaled_inputs[i], __nv_saturation_t::__NV_SATFINITE, out_dtype);
+    }
+  }
+}
+
+template <int coarse_factor>
 __global__ void saturated_cast_kernel_double_coalesced(
     nv_bfloat162 const *__restrict input,
     __nv_fp8x2_storage_t *__restrict output, int n_rows, int n_cols,
@@ -59,9 +93,8 @@ __global__ void saturated_cast_kernel_double_coalesced(
     const int temp_col = col + i;
     if (row < n_rows && temp_col < n_cols) {
       output[row * row_stride + temp_col * col_stride] =
-          __nv_cvt_bfloat16raw2_to_fp8x2(scaled_inputs[i],
-                                         __nv_saturation_t::__NV_SATFINITE,
-                                         out_dtype);
+          __nv_cvt_bfloat16raw2_to_fp8x2(
+              scaled_inputs[i], __nv_saturation_t::__NV_SATFINITE, out_dtype);
     }
   }
 }
@@ -84,8 +117,26 @@ void dispatch_best_kernel(const Tensor &input, const Tensor &output,
   const int n_cols = input.size(1);
   const int block_size_x = 32;
   const int block_size_y = 32;
-  if (n_cols % 2 == 0) {
-    // We cast to a 2x8 type, so we need to divide the number of columns by 2
+  const auto numel = input.numel();
+  int kernel_choice = 0;
+  if (numel % 2 == 0 && !transpose) {
+    kernel_choice = 2;
+  } else if (n_cols % 2 == 0) {
+    kernel_choice = 1;
+  }
+  switch (kernel_choice) {
+  case 0: {
+    const dim3 block(block_size_x, block_size_y);
+    const dim3 grid(ceil_div(n_cols, block_size_x),
+                    ceil_div(n_rows, block_size_y));
+    saturated_cast_kernel_single<<<grid, block>>>(
+        static_cast<nv_bfloat16 *>(input.data_ptr()),
+        static_cast<__nv_fp8_storage_t *>(output.data_ptr()), n_rows, n_cols,
+        out_dtype, static_cast<nv_bfloat16 *>(scale.data_ptr()));
+    break;
+  }
+  case 1: {
+    // We cast to a 16x2 type, so we need to divide the number of columns by 2
     const auto packed_col_size = n_cols / 2;
     // Found 4 to be the best factor for the coalesced kernel
     const int coarse_factor = 4;
@@ -97,14 +148,20 @@ void dispatch_best_kernel(const Tensor &input, const Tensor &output,
         static_cast<__nv_fp8x2_storage_t *>(output.data_ptr()), n_rows,
         packed_col_size, out_dtype,
         static_cast<nv_bfloat16 *>(scale.data_ptr()));
-  } else {
-    const dim3 block(block_size_x, block_size_y);
-    const dim3 grid(ceil_div(n_cols, block_size_x),
-                    ceil_div(n_rows, block_size_y));
-    saturated_cast_kernel_single<<<grid, block>>>(
-        static_cast<nv_bfloat16 *>(input.data_ptr()),
-        static_cast<__nv_fp8_storage_t *>(output.data_ptr()), n_rows, n_cols,
+    break;
+  }
+  case 2: {
+    const int coarse_factor = 4;
+    const dim3 block(256);
+    const int packed_numel = numel / 2;
+    // We divide numel by 2 because we are casting to a 16x2 type
+    const dim3 grid(ceil_div(packed_numel, block.x * coarse_factor));
+    saturated_cast_kernel_double_coalesced_flat<coarse_factor><<<grid, block>>>(
+        static_cast<nv_bfloat162 *>(input.data_ptr()),
+        static_cast<__nv_fp8x2_storage_t *>(output.data_ptr()), packed_numel,
         out_dtype, static_cast<nv_bfloat16 *>(scale.data_ptr()));
+    break;
+  }
   }
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
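
Not part of the commit, but as a reading aid: the three unrolled loops in saturated_cast_kernel_double_coalesced_flat amount to the per-pair operation sketched below. The helper name is hypothetical, and it assumes the same cuda_bf16.h / cuda_fp8.h headers the kernel file already uses: multiply both bf16 lanes by the broadcast scale, then convert the pair to two fp8 values with saturate-to-finite handling of out-of-range results.

// Hypothetical device helper, not in the commit; assumes cuda_bf16.h and
// cuda_fp8.h are included, as they are for the kernels above.
__device__ __forceinline__ __nv_fp8x2_storage_t
scale_and_cast_pair(nv_bfloat162 v, nv_bfloat162 scale_2,
                    __nv_fp8_interpretation_t out_dtype) {
  // Multiply both bf16 lanes by the scale in one instruction.
  const nv_bfloat162 scaled = __hmul2(v, scale_2);
  // Convert the scaled pair to two packed fp8 values, saturating values that
  // do not fit the target format to its largest finite value.
  return __nv_cvt_bfloat16raw2_to_fp8x2(
      scaled, __nv_saturation_t::__NV_SATFINITE, out_dtype);
}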
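
Also not from the commit: a minimal host-side sketch of the case-2 launch math, assuming the flat kernel above is in scope, a ceil_div helper like the one dispatch_best_kernel uses, a contiguous bf16 buffer with an even element count, and a 2-byte-aligned fp8 output buffer. Each thread handles coarse_factor packed nv_bfloat162 values, so a block of 256 threads covers 256 * 4 = 1024 packed elements and the grid is sized to the round-up of packed_numel / 1024.

// Hypothetical standalone launcher mirroring case 2 of dispatch_best_kernel.
#include <cuda_bf16.h>
#include <cuda_fp8.h>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Assumes numel is even and output is 2-byte aligned, so the buffers can be
// reinterpreted as nv_bfloat162 / __nv_fp8x2_storage_t pairs.
void launch_flat_cast(const nv_bfloat16 *input, __nv_fp8_storage_t *output,
                      int numel, const nv_bfloat16 *device_scale,
                      __nv_fp8_interpretation_t out_dtype) {
  constexpr int coarse_factor = 4;     // same coarsening factor as the commit
  const dim3 block(256);
  const int packed_numel = numel / 2;  // bf16 elements -> bf16x2 pairs
  const dim3 grid(ceil_div(packed_numel, block.x * coarse_factor));
  saturated_cast_kernel_double_coalesced_flat<coarse_factor><<<grid, block>>>(
      reinterpret_cast<const nv_bfloat162 *>(input),
      reinterpret_cast<__nv_fp8x2_storage_t *>(output), packed_numel,
      out_dtype, device_scale);
}

The temp_idx < numels bounds check inside the kernel covers the final partially filled block, so the rounded-up grid never reads or writes past the packed element count.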