@@ -2,37 +2,37 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import dataclass
-from typing import Union
-
 from cuda.core.experimental._device import Device
 from cuda.core.experimental._utils.cuda_utils import (
     CUDAError,
     cast_to_3_tuple,
     driver,
     get_binding_version,
     handle_return,
 )
 
 # TODO: revisit this treatment for py313t builds
-_inited = False
+cdef bint _inited = False
+cdef bint _use_ex = False
 
 
-def _lazy_init():
-    global _inited
+cdef void _lazy_init() except *:
+    """Initialize module-level globals for driver version checks."""
+    global _inited, _use_ex
     if _inited:
         return
 
-    global _use_ex
+    cdef tuple _py_major_minor
+    cdef int _driver_ver
+
     # binding availability depends on cuda-python version
     _py_major_minor = get_binding_version()
     _driver_ver = handle_return(driver.cuDriverGetVersion())
     _use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
     _inited = True
 
 
-@dataclass
-class LaunchConfig:
+cdef class LaunchConfig:
     """Customizable launch options.
 
     Note
@@ -65,21 +65,36 @@ class LaunchConfig:
     """
 
     # TODO: expand LaunchConfig to include other attributes
-    grid: Union[tuple, int] = None
-    cluster: Union[tuple, int] = None
-    block: Union[tuple, int] = None
-    shmem_size: int | None = None
-    cooperative_launch: bool | None = False
-
-    def __post_init__(self):
+    # Note: attributes are declared in _launch_config.pxd
+
+    def __init__(self, grid=None, cluster=None, block=None,
+                 shmem_size=None, cooperative_launch=False):
+        """Initialize LaunchConfig with validation.
+
+        Parameters
+        ----------
+        grid : Union[tuple, int], optional
+            Grid dimensions (number of blocks, or of clusters if cluster is specified)
+        cluster : Union[tuple, int], optional
+            Cluster dimensions (thread block cluster)
+        block : Union[tuple, int], optional
+            Block dimensions (threads per block)
+        shmem_size : int, optional
+            Dynamic shared memory size in bytes (default: 0)
+        cooperative_launch : bool, optional
+            Whether to launch as a cooperative kernel (default: False)
+        """
         _lazy_init()
-        self.grid = cast_to_3_tuple("LaunchConfig.grid", self.grid)
-        self.block = cast_to_3_tuple("LaunchConfig.block", self.block)
+
+        # Convert and validate grid and block dimensions
+        self.grid = cast_to_3_tuple("LaunchConfig.grid", grid)
+        self.block = cast_to_3_tuple("LaunchConfig.block", block)
+
         # FIXME: Calling Device() strictly speaking is not quite right; we should instead
         # look up the device from stream. We probably need to defer the checks related to
         # device compute capability or attributes.
         # thread block clusters are supported starting H100
-        if self.cluster is not None:
+        if cluster is not None:
             if not _use_ex:
                 err, drvers = driver.cuDriverGetVersion()
                 drvers_fmt = f" (got driver version {drvers})" if err == driver.CUresult.CUDA_SUCCESS else ""
@@ -89,19 +104,53 @@ def __post_init__(self):
                 raise CUDAError(
                     f"thread block clusters are not supported on devices with compute capability < 9.0 (got {cc})"
                 )
-            self.cluster = cast_to_3_tuple("LaunchConfig.cluster", self.cluster)
-        if self.shmem_size is None:
+            self.cluster = cast_to_3_tuple("LaunchConfig.cluster", cluster)
+        else:
+            self.cluster = None
+
+        # Handle the shmem_size default
+        if shmem_size is None:
             self.shmem_size = 0
+        else:
+            self.shmem_size = shmem_size
+
+        # Handle cooperative_launch
+        self.cooperative_launch = cooperative_launch
+
+        # Validate cooperative launch support
         if self.cooperative_launch and not Device().properties.cooperative_launch:
             raise CUDAError("cooperative kernels are not supported on this device")
 
+    def __repr__(self):
+        """Return the string representation of this LaunchConfig."""
+        return (f"LaunchConfig(grid={self.grid}, cluster={self.cluster}, "
+                f"block={self.block}, shmem_size={self.shmem_size}, "
+                f"cooperative_launch={self.cooperative_launch})")
+
 
-def _to_native_launch_config(config: LaunchConfig) -> driver.CUlaunchConfig:
+cpdef object _to_native_launch_config(LaunchConfig config):
+    """Convert a LaunchConfig to a native driver CUlaunchConfig.
+
+    Parameters
+    ----------
+    config : LaunchConfig
+        High-level launch configuration
+
+    Returns
+    -------
+    driver.CUlaunchConfig
+        Native CUDA driver launch configuration
+    """
     _lazy_init()
-    drv_cfg = driver.CUlaunchConfig()
+
+    cdef object drv_cfg = driver.CUlaunchConfig()
+    cdef list attrs
+    cdef object attr
+    cdef object dim
+    cdef tuple grid_blocks
 
     # Handle grid dimensions and cluster configuration
-    if config.cluster:
+    if config.cluster is not None:
         # Convert grid from cluster units to block units
         grid_blocks = (
             config.grid[0] * config.cluster[0],
@@ -122,11 +171,14 @@ def _to_native_launch_config(config: LaunchConfig) -> driver.CUlaunchConfig:
 
     drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
     drv_cfg.sharedMemBytes = config.shmem_size
+
     if config.cooperative_launch:
         attr = driver.CUlaunchAttribute()
         attr.id = driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_COOPERATIVE
         attr.value.cooperative = 1
         attrs.append(attr)
+
     drv_cfg.numAttrs = len(attrs)
     drv_cfg.attrs = attrs
+
     return drv_cfg
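
Two short sketches follow for context. First, the `# Note: attributes are declared in _launch_config.pxd` comment in the diff refers to a companion declaration file. A hypothetical sketch of what those declarations could look like is below; the attribute names mirror the assignments in `__init__`, but the exact types in the real `_launch_config.pxd` may differ.

```cython
# Hypothetical sketch of _launch_config.pxd (assumed, not the actual file).
# Attributes of a cdef class must be declared once, here, so the .pyx body
# can assign them without redeclaration; `public` exposes them to Python.
cdef class LaunchConfig:
    cdef public tuple grid                 # normalized to a 3-tuple in __init__
    cdef public tuple cluster              # 3-tuple, or None when no cluster is used
    cdef public tuple block                # normalized to a 3-tuple in __init__
    cdef public int shmem_size             # dynamic shared memory, in bytes
    cdef public bint cooperative_launch    # cooperative kernel flag

# Declaring the cpdef converter here would let other Cython modules cimport it.
cpdef object _to_native_launch_config(LaunchConfig config)
```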
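Second, a minimal usage sketch of the behavior this commit preserves. It assumes a CUDA-capable GPU with the `cuda.core` package installed; `_to_native_launch_config` is a private helper, imported here only for illustration.

```python
from cuda.core.experimental import Device
from cuda.core.experimental._launch_config import (
    LaunchConfig,
    _to_native_launch_config,
)

Device().set_current()  # the Device() checks in __init__ need an initialized device

# 4 blocks of 256 threads with 1 KiB of dynamic shared memory; grid and block
# are normalized to 3-tuples by cast_to_3_tuple
cfg = LaunchConfig(grid=4, block=256, shmem_size=1024)
print(cfg)  # LaunchConfig(grid=(4, 1, 1), cluster=None, block=(256, 1, 1), ...)

# The conversion performed internally before a cuLaunchKernelEx-style launch
drv_cfg = _to_native_launch_config(cfg)
```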