Skip to content

Commit ceaa71c

Browse files
committed
set CUDA context
1 parent 57ad040 commit ceaa71c

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

torchao/csrc/cuda/fp6_llm/fp6_linear.cu

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream,
116116

117117
#include <torch/extension.h>
118118
#include <ATen/ATen.h>
119+
#include <ATen/cuda/CUDAContext.h>
119120
#include <torch/library.h>
120121

121122
namespace torchao {
@@ -166,23 +167,27 @@ torch::Tensor fp_eXmY_linear_forward_cuda(
166167
at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options);
167168
auto Reduction_Workspace = reinterpret_cast<float*>(_workspace.data_ptr<float>()); // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)
168169

170+
// MODIFICATION NOTE: use at::cuda::getCurrentCUDAStream() instead of default stream (0)
171+
// this fixes problem with CUDA graphs when used with torch.compile()
172+
auto stream = at::cuda::getCurrentCUDAStream();
173+
169174
// officially supported in Quant-LLM
170175
if (EXPONENT == 3 && MANTISSA == 2)
171-
fpx_linear_kernel<3, 2>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
176+
fpx_linear_kernel<3, 2>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
172177
else if (EXPONENT == 2 && MANTISSA == 2)
173-
fpx_linear_kernel<2, 2>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
178+
fpx_linear_kernel<2, 2>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
174179

175180
// experimental
176181
else if (EXPONENT == 2 && MANTISSA == 3)
177-
fpx_linear_kernel<2, 3>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
182+
fpx_linear_kernel<2, 3>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
178183
else if (EXPONENT == 3 && MANTISSA == 1)
179-
fpx_linear_kernel<3, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
184+
fpx_linear_kernel<3, 1>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
180185
// else if (EXPONENT == 2 && MANTISSA == 1)
181-
// fpx_linear_kernel<2, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
186+
// fpx_linear_kernel<2, 1>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
182187
// else if (EXPONENT == 3 && MANTISSA == 0)
183-
// fpx_linear_kernel<3, 0>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
188+
// fpx_linear_kernel<3, 0>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
184189
// else if (EXPONENT == 2 && MANTISSA == 0)
185-
// fpx_linear_kernel<2, 0>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
190+
// fpx_linear_kernel<2, 0>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK);
186191

187192
else
188193
TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA, " is not supported.");

0 commit comments

Comments
 (0)