Skip to content

Commit ae3f200

Browse files
committed
Implementing greptile feedback
1 parent d2856cc commit ae3f200

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

cuda_bindings/tests/helpers.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import os
5+
import pathlib
6+
import sys
7+
8+
CUDA_PATH = os.environ.get("CUDA_PATH")
9+
CUDA_INCLUDE_PATH = None
10+
CCCL_INCLUDE_PATHS = None
11+
if CUDA_PATH is not None:
12+
path = os.path.join(CUDA_PATH, "include")
13+
if os.path.isdir(path):
14+
CUDA_INCLUDE_PATH = path
15+
CCCL_INCLUDE_PATHS = (path,)
16+
path = os.path.join(path, "cccl")
17+
if os.path.isdir(path):
18+
CCCL_INCLUDE_PATHS = (path,) + CCCL_INCLUDE_PATHS
19+
20+
21+
try:
22+
import cuda_python_test_helpers
23+
except ImportError:
24+
# Import shared platform helpers for tests across repos
25+
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[2] / "cuda_python_test_helpers"))
26+
import cuda_python_test_helpers
27+
28+
# If we imported the package instead of the module, get the actual module
29+
if hasattr(cuda_python_test_helpers, '__path__'):
30+
# We imported the package, need to get the actual module
31+
import cuda_python_test_helpers.cuda_python_test_helpers as cuda_python_test_helpers
32+
33+
34+
IS_WSL = cuda_python_test_helpers.IS_WSL
35+
supports_ipc_mempool = cuda_python_test_helpers.supports_ipc_mempool
36+
37+
38+
del cuda_python_test_helpers

cuda_bindings/tests/test_graphics_apis.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import pytest
55
from cuda.bindings import runtime as cudart
66

7+
from helpers import IS_WSL
78

9+
@pytest.mark.skipif(IS_WSL, reason="Graphics interop not supported on this platform")
810
def test_graphics_api_smoketest():
911
# Due to lazy importing in pyglet, pytest.importorskip doesn't work
1012
try:
@@ -26,6 +28,7 @@ def test_graphics_api_smoketest():
2628
assert error_name in ("cudaErrorInvalidValue", "cudaErrorUnknown", "cudaErrorOperatingSystem")
2729

2830

31+
@pytest.mark.skipif(IS_WSL, reason="Graphics interop not supported on this platform")
2932
def test_cuda_register_image_invalid():
3033
"""Exercise cudaGraphicsGLRegisterImage with dummy handle only using CUDA runtime API."""
3134
fake_gl_texture_id = 1

0 commit comments

Comments
 (0)