@@ -116,6 +116,7 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream,
 #include <torch/extension.h>
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <torch/library.h>

 namespace torchao {
@@ -166,23 +167,27 @@ torch::Tensor fp_eXmY_linear_forward_cuda(
     at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options);
     auto Reduction_Workspace = reinterpret_cast<float*>(_workspace.data_ptr<float>());  // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)

+    // MODIFICATION NOTE: use at::cuda::getCurrentCUDAStream() instead of default stream (0)
+    // this fixes problem with CUDA graphs when used with torch.compile()
+    auto stream = at::cuda::getCurrentCUDAStream();
+
     // officially supported in Quant-LLM
     if (EXPONENT == 3 && MANTISSA == 2)
-        fpx_linear_kernel<3, 2>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
+        fpx_linear_kernel<3, 2>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
     else if (EXPONENT == 2 && MANTISSA == 2)
-        fpx_linear_kernel<2, 2>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
+        fpx_linear_kernel<2, 2>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);

     // experimental
     else if (EXPONENT == 2 && MANTISSA == 3)
-        fpx_linear_kernel<2, 3>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
+        fpx_linear_kernel<2, 3>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
     else if (EXPONENT == 3 && MANTISSA == 1)
-        fpx_linear_kernel<3, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
+        fpx_linear_kernel<3, 1>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
     // else if (EXPONENT == 2 && MANTISSA == 1)
-    //     fpx_linear_kernel<2, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
+    //     fpx_linear_kernel<2, 1>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
     // else if (EXPONENT == 3 && MANTISSA == 0)
-    //     fpx_linear_kernel<3, 0>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
+    //     fpx_linear_kernel<3, 0>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
     // else if (EXPONENT == 2 && MANTISSA == 0)
-    //     fpx_linear_kernel<2, 0>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
+    //     fpx_linear_kernel<2, 0>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);

     else
         TORCH_CHECK(false, "FP", NBITS, "E", EXPONENT, "M", MANTISSA, " is not supported.");
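
The MODIFICATION NOTE above concerns CUDA graph capture: torch.compile's CUDA-graphs path records only the work enqueued on the capturing stream (which PyTorch makes the current stream during capture), so a kernel launched on the hard-coded default stream 0 falls outside the capture and breaks it. The standalone CUDA sketch below illustrates that constraint; dummy_kernel and the capture boilerplate are illustrative placeholders, not part of the torchao sources.

// Minimal sketch: a kernel must target the capturing stream, not the legacy
// default stream (0), while CUDA graph capture is active (as it is under
// torch.compile(mode="reduce-overhead")).
#include <cuda_runtime.h>
#include <cstdio>

// Placeholder standing in for the Quant-LLM GEMM kernel.
__global__ void dummy_kernel(float* x) { x[0] += 1.0f; }

int main() {
    float* d = nullptr;
    cudaMalloc(&d, sizeof(float));

    cudaStream_t capture_stream;
    cudaStreamCreate(&capture_stream);

    cudaGraph_t graph = nullptr;
    cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal);

    // OK: enqueued on the stream being captured. Passing
    // at::cuda::getCurrentCUDAStream() down to fpx_linear_kernel has the same
    // effect, because the capturing stream is the current stream.
    dummy_kernel<<<1, 1, 0, capture_stream>>>(d);

    // NOT OK while a global-mode capture is in progress: launching on the
    // legacy default stream (the old hard-coded `0` argument) is disallowed
    // by the CUDA runtime and would make the capture fail.
    // dummy_kernel<<<1, 1>>>(d);

    cudaError_t err = cudaStreamEndCapture(capture_stream, &graph);
    std::printf("capture: %s\n", cudaGetErrorString(err));  // "no error" as written

    if (graph) cudaGraphDestroy(graph);
    cudaStreamDestroy(capture_stream);
    cudaFree(d);
    return 0;
}

Compiled with nvcc, the sketch captures successfully as written; uncommenting the default-stream launch triggers a stream-capture error, which is the failure mode the diff removes from fp_eXmY_linear_forward_cuda.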