
Commit 69aac67

rparolin and Copilot authored
cythonize _launch_config.py (#1221)
* cythonize _launch_config
* pre-commit
* moving back
* removing comment
* reverting list code back to original
* pre-commit
* removing unused import
* Update cuda_core/cuda/core/experimental/_launch_config.pyx

Co-authored-by: Copilot <[email protected]>
1 parent 02588a6 commit 69aac67

File tree

4 files changed (+97, -26 lines):
  cuda_core/cuda/core/experimental/_launch_config.pxd (new)
  cuda_core/cuda/core/experimental/_launch_config.py → _launch_config.pyx (renamed)
  cuda_core/cuda/core/experimental/_launcher.pyx
  cuda_core/tests/test_launcher.py
cuda_core/cuda/core/experimental/_launch_config.pxd (new file)

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+
+cdef bint _inited
+cdef bint _use_ex
+
+cdef void _lazy_init() except *
+
+cdef class LaunchConfig:
+    """Customizable launch options."""
+    cdef public tuple grid
+    cdef public tuple cluster
+    cdef public tuple block
+    cdef public int shmem_size
+    cdef public bint cooperative_launch
+
+cpdef object _to_native_launch_config(LaunchConfig config)
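The .pxd gives other Cython modules compile-time access to these declarations via cimport, while the cdef public attributes stay readable and writable from ordinary Python. One observable difference from the old dataclass: assignments to the typed slots are now type-checked by Cython. A minimal sketch of the Python-visible behavior (assuming a built cuda.core and an available CUDA driver):

from cuda.core.experimental import LaunchConfig

cfg = LaunchConfig(grid=1, block=64)  # __init__ normalizes these to 3-tuples
cfg.shmem_size = 2048                 # OK: `cdef public int` accepts a Python int
cfg.shmem_size = "2048"               # TypeError: the typed slot rejects a str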

cuda_core/cuda/core/experimental/_launch_config.py renamed to cuda_core/cuda/core/experimental/_launch_config.pyx

Lines changed: 76 additions & 24 deletions

@@ -2,9 +2,6 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import dataclass
-from typing import Union
-
 from cuda.core.experimental._device import Device
 from cuda.core.experimental._utils.cuda_utils import (
     CUDAError,
@@ -15,24 +12,27 @@
 )
 
 # TODO: revisit this treatment for py313t builds
-_inited = False
+cdef bint _inited = False
+cdef bint _use_ex = False
 
 
-def _lazy_init():
-    global _inited
+cdef void _lazy_init() except *:
+    """Initialize module-level globals for driver version checks."""
+    global _inited, _use_ex
     if _inited:
         return
 
-    global _use_ex
+    cdef tuple _py_major_minor
+    cdef int _driver_ver
+
     # binding availability depends on cuda-python version
     _py_major_minor = get_binding_version()
     _driver_ver = handle_return(driver.cuDriverGetVersion())
     _use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
     _inited = True
 
 
-@dataclass
-class LaunchConfig:
+cdef class LaunchConfig:
     """Customizable launch options.
 
     Note
@@ -65,21 +65,36 @@ class LaunchConfig:
     """
 
     # TODO: expand LaunchConfig to include other attributes
-    grid: Union[tuple, int] = None
-    cluster: Union[tuple, int] = None
-    block: Union[tuple, int] = None
-    shmem_size: int | None = None
-    cooperative_launch: bool | None = False
-
-    def __post_init__(self):
+    # Note: attributes are declared in _launch_config.pxd
+
+    def __init__(self, grid=None, cluster=None, block=None,
+                 shmem_size=None, cooperative_launch=False):
+        """Initialize LaunchConfig with validation.
+
+        Parameters
+        ----------
+        grid : Union[tuple, int], optional
+            Grid dimensions (number of blocks or clusters if cluster is specified)
+        cluster : Union[tuple, int], optional
+            Cluster dimensions (Thread Block Cluster)
+        block : Union[tuple, int], optional
+            Block dimensions (threads per block)
+        shmem_size : int, optional
+            Dynamic shared memory size in bytes (default: 0)
+        cooperative_launch : bool, optional
+            Whether to launch as cooperative kernel (default: False)
+        """
         _lazy_init()
-        self.grid = cast_to_3_tuple("LaunchConfig.grid", self.grid)
-        self.block = cast_to_3_tuple("LaunchConfig.block", self.block)
+
+        # Convert and validate grid and block dimensions
+        self.grid = cast_to_3_tuple("LaunchConfig.grid", grid)
+        self.block = cast_to_3_tuple("LaunchConfig.block", block)
+
         # FIXME: Calling Device() strictly speaking is not quite right; we should instead
         # look up the device from stream. We probably need to defer the checks related to
         # device compute capability or attributes.
         # thread block clusters are supported starting H100
-        if self.cluster is not None:
+        if cluster is not None:
             if not _use_ex:
                 err, drvers = driver.cuDriverGetVersion()
                 drvers_fmt = f" (got driver version {drvers})" if err == driver.CUresult.CUDA_SUCCESS else ""
@@ -89,19 +104,53 @@ def __post_init__(self):
                 raise CUDAError(
                     f"thread block clusters are not supported on devices with compute capability < 9.0 (got {cc})"
                 )
-            self.cluster = cast_to_3_tuple("LaunchConfig.cluster", self.cluster)
-        if self.shmem_size is None:
+            self.cluster = cast_to_3_tuple("LaunchConfig.cluster", cluster)
+        else:
+            self.cluster = None
+
+        # Handle shmem_size default
+        if shmem_size is None:
             self.shmem_size = 0
+        else:
+            self.shmem_size = shmem_size
+
+        # Handle cooperative_launch
+        self.cooperative_launch = cooperative_launch
+
+        # Validate cooperative launch support
         if self.cooperative_launch and not Device().properties.cooperative_launch:
             raise CUDAError("cooperative kernels are not supported on this device")
 
+    def __repr__(self):
+        """Return string representation of LaunchConfig."""
+        return (f"LaunchConfig(grid={self.grid}, cluster={self.cluster}, "
+                f"block={self.block}, shmem_size={self.shmem_size}, "
+                f"cooperative_launch={self.cooperative_launch})")
+
 
-def _to_native_launch_config(config: LaunchConfig) -> driver.CUlaunchConfig:
+cpdef object _to_native_launch_config(LaunchConfig config):
+    """Convert LaunchConfig to native driver CUlaunchConfig.
+
+    Parameters
+    ----------
+    config : LaunchConfig
+        High-level launch configuration
+
+    Returns
+    -------
+    driver.CUlaunchConfig
+        Native CUDA driver launch configuration
+    """
     _lazy_init()
-    drv_cfg = driver.CUlaunchConfig()
+
+    cdef object drv_cfg = driver.CUlaunchConfig()
+    cdef list attrs
+    cdef object attr
+    cdef object dim
+    cdef tuple grid_blocks
 
     # Handle grid dimensions and cluster configuration
-    if config.cluster:
+    if config.cluster is not None:
         # Convert grid from cluster units to block units
         grid_blocks = (
             config.grid[0] * config.cluster[0],
@@ -122,11 +171,14 @@ def _to_native_launch_config(config: LaunchConfig) -> driver.CUlaunchConfig:
 
     drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
     drv_cfg.sharedMemBytes = config.shmem_size
+
    if config.cooperative_launch:
         attr = driver.CUlaunchAttribute()
         attr.id = driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_COOPERATIVE
         attr.value.cooperative = 1
         attrs.append(attr)
+
     drv_cfg.numAttrs = len(attrs)
     drv_cfg.attrs = attrs
+
     return drv_cfg
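From Python, the cythonized class behaves as before: scalars and short tuples passed to grid/block are normalized by cast_to_3_tuple, shmem_size=None falls back to 0, and the hand-written __repr__ restores the readable output that @dataclass used to generate. When cluster is set, _to_native_launch_config multiplies grid by cluster per axis, since the driver expects gridDim in blocks rather than clusters. A usage sketch (plain Python, assuming a CUDA-capable machine and that cast_to_3_tuple pads missing dimensions with 1):

from cuda.core.experimental import LaunchConfig

cfg = LaunchConfig(grid=32, block=(128, 1))
print(cfg.grid)        # (32, 1, 1)
print(cfg.block)       # (128, 1, 1)
print(cfg.shmem_size)  # 0 -- the None default is replaced in __init__
print(cfg)             # LaunchConfig(grid=(32, 1, 1), cluster=None, ...)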

cuda_core/cuda/core/experimental/_launcher.pyx

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ from cuda.core.experimental._stream cimport _try_to_get_stream_ptr
 from typing import Union
 
 from cuda.core.experimental._kernel_arg_handler import ParamHolder
-from cuda.core.experimental._launch_config import LaunchConfig, _to_native_launch_config
+from cuda.core.experimental._launch_config cimport LaunchConfig, _to_native_launch_config
 from cuda.core.experimental._module import Kernel
 from cuda.core.experimental._stream import IsStreamT, Stream
 from cuda.core.experimental._utils.clear_error_support import assert_type
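Because the launcher now cimports these symbols, they bind at compile time: _to_native_launch_config is reached through its cpdef C entry point rather than a Python attribute lookup, and its LaunchConfig-typed parameter is checked at the call boundary. A hypothetical Cython consumer (illustrative only, not part of this commit) showing the same pattern:

# sketch.pyx -- assumes the _launch_config.pxd shown above is installed
from cuda.core.experimental._launch_config cimport LaunchConfig, _to_native_launch_config

cpdef object native_cfg(LaunchConfig config):
    # Cython verifies the argument is a LaunchConfig (or None) at the boundary;
    # any other type raises TypeError before this body runs.
    return _to_native_launch_config(config)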

cuda_core/tests/test_launcher.py

Lines changed: 1 addition & 1 deletion

@@ -131,7 +131,7 @@ def test_launch_invalid_values(init_cuda):
     ker = mod.get_kernel("my_kernel")
     config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), shmem_size=0)
 
-    with pytest.raises(ValueError):
+    with pytest.raises(TypeError):
         launch(None, ker, config)
 
     with pytest.raises(TypeError):
