|
15 | 15 | from cuda import cudart as runtime
|
16 | 16 | from cuda import nvrtc
|
17 | 17 |
|
| 18 | +from cuda.core.experimental._utils.driver_cu_result_explanations import DRIVER_CU_RESULT_EXPLANATIONS |
| 19 | +from cuda.core.experimental._utils.runtime_cuda_error_explanations import RUNTIME_CUDA_ERROR_EXPLANATIONS |
| 20 | + |
18 | 21 |
|
19 | 22 | class CUDAError(Exception):
|
20 | 23 | pass
|
@@ -45,27 +48,45 @@ def cast_to_3_tuple(label, cfg):
|
45 | 48 | return cfg + (1,) * (3 - len(cfg))
|
46 | 49 |
|
47 | 50 |
|
| 51 | +def _check_driver_error(error): |
| 52 | + if error == driver.CUresult.CUDA_SUCCESS: |
| 53 | + return |
| 54 | + name_err, name = driver.cuGetErrorName(error) |
| 55 | + if name_err != driver.CUresult.CUDA_SUCCESS: |
| 56 | + raise CUDAError(f"UNEXPECTED ERROR CODE: {error}") |
| 57 | + name = name.decode() |
| 58 | + expl = DRIVER_CU_RESULT_EXPLANATIONS.get(int(error)) |
| 59 | + if expl is not None: |
| 60 | + raise CUDAError(f"{name}: {expl}") |
| 61 | + desc_err, desc = driver.cuGetErrorString(error) |
| 62 | + if desc_err != driver.CUresult.CUDA_SUCCESS: |
| 63 | + raise CUDAError(f"{name}") |
| 64 | + desc = desc.decode() |
| 65 | + raise CUDAError(f"{name}: {desc}") |
| 66 | + |
| 67 | + |
| 68 | +def _check_runtime_error(error): |
| 69 | + if error == runtime.cudaError_t.cudaSuccess: |
| 70 | + return |
| 71 | + name_err, name = runtime.cudaGetErrorName(error) |
| 72 | + if name_err != runtime.cudaError_t.cudaSuccess: |
| 73 | + raise CUDAError(f"UNEXPECTED ERROR CODE: {error}") |
| 74 | + name = name.decode() |
| 75 | + expl = RUNTIME_CUDA_ERROR_EXPLANATIONS.get(int(error)) |
| 76 | + if expl is not None: |
| 77 | + raise CUDAError(f"{name}: {expl}") |
| 78 | + desc_err, desc = runtime.cudaGetErrorString(error) |
| 79 | + if desc_err != runtime.cudaError_t.cudaSuccess: |
| 80 | + raise CUDAError(f"{name}") |
| 81 | + desc = desc.decode() |
| 82 | + raise CUDAError(f"{name}: {desc}") |
| 83 | + |
| 84 | + |
48 | 85 | def _check_error(error, handle=None):
|
49 | 86 | if isinstance(error, driver.CUresult):
|
50 |
| - if error == driver.CUresult.CUDA_SUCCESS: |
51 |
| - return |
52 |
| - err, name = driver.cuGetErrorName(error) |
53 |
| - if err == driver.CUresult.CUDA_SUCCESS: |
54 |
| - err, desc = driver.cuGetErrorString(error) |
55 |
| - if err == driver.CUresult.CUDA_SUCCESS: |
56 |
| - raise CUDAError(f"{name.decode()}: {desc.decode()}") |
57 |
| - else: |
58 |
| - raise CUDAError(f"unknown error: {error}") |
| 87 | + _check_driver_error(error) |
59 | 88 | elif isinstance(error, runtime.cudaError_t):
|
60 |
| - if error == runtime.cudaError_t.cudaSuccess: |
61 |
| - return |
62 |
| - err, name = runtime.cudaGetErrorName(error) |
63 |
| - if err == runtime.cudaError_t.cudaSuccess: |
64 |
| - err, desc = runtime.cudaGetErrorString(error) |
65 |
| - if err == runtime.cudaError_t.cudaSuccess: |
66 |
| - raise CUDAError(f"{name.decode()}: {desc.decode()}") |
67 |
| - else: |
68 |
| - raise CUDAError(f"unknown error: {error}") |
| 89 | + _check_runtime_error(error) |
69 | 90 | elif isinstance(error, nvrtc.nvrtcResult):
|
70 | 91 | if error == nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
71 | 92 | return
|
|
0 commit comments