Skip to content

Commit 4c25b0d

Browse files
committed
Improve tests
1 parent 8d257b6 commit 4c25b0d

File tree

2 files changed

+28
-29
lines changed

2 files changed

+28
-29
lines changed

cuda_bindings/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ test = [
4242
"numpy>=1.21.1",
4343
"pytest>=6.2.4",
4444
"pytest-benchmark>=3.4.1",
45+
"pyglet>=2.1.9"
4546
]
4647

4748
[project.urls]

cuda_bindings/tests/test_graphics_apis.py

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,30 @@
66

77

88
def test_graphics_api_smoketest():
9-
_ = pytest.importorskip("PySide6")
10-
from PySide6 import QtGui, QtOpenGL
11-
12-
class GLWidget(QtOpenGL.QOpenGLWindow):
13-
def initializeGL(self):
14-
self.m_texture = QtOpenGL.QOpenGLTexture(QtOpenGL.QOpenGLTexture.Target.Target2D)
15-
self.m_texture.setFormat(QtOpenGL.QOpenGLTexture.TextureFormat.RGBA8_UNorm)
16-
self.m_texture.setSize(512, 512)
17-
self.m_texture.allocateStorage()
18-
19-
err, self.gfx_resource = cudart.cudaGraphicsGLRegisterImage(
20-
self.m_texture.textureId(),
21-
self.m_texture.target().value,
22-
cudart.cudaGraphicsRegisterFlags.cudaGraphicsRegisterFlagsWriteDiscard,
23-
)
24-
error_name = cudart.cudaGetErrorName(err)[1].decode()
25-
26-
# We either have everything set up correctly and we get a gfx_resource,
27-
# or we get an error. Either way, we know the API actually did something,
28-
# which is enough for this basic smoketest.
29-
if error_name == "cudaSuccess":
30-
assert int(self.gfx_resource) != 0
31-
else:
32-
assert error_name == "cudaErrorInvalidValue"
33-
34-
app = QtGui.QGuiApplication([])
35-
win = GLWidget()
36-
win.initializeGL()
37-
win.show()
9+
pyglet = pytest.importorskip("pyglet")
10+
11+
tex = pyglet.image.Texture.create(512, 512)
12+
13+
err, gfx_resource = cudart.cudaGraphicsGLRegisterImage(
14+
tex.id, tex.target, cudart.cudaGraphicsRegisterFlags.cudaGraphicsRegisterFlagsWriteDiscard
15+
)
16+
error_name = cudart.cudaGetErrorName(err)[1].decode()
17+
if error_name == "cudaSuccess":
18+
assert int(gfx_resource) != 0
19+
else:
20+
assert error_name in ("cudaErrorInvalidValue", "cudaErrorUnknown")
21+
22+
23+
def test_cuda_register_image_invalid():
24+
"""Exercise cudaGraphicsGLRegisterImage with dummy handle only using CUDA runtime API."""
25+
fake_gl_texture_id = 1
26+
fake_gl_target = 0x0DE1
27+
flags = cudart.cudaGraphicsRegisterFlags.cudaGraphicsRegisterFlagsWriteDiscard
28+
29+
err, resource = cudart.cudaGraphicsGLRegisterImage(fake_gl_texture_id, fake_gl_target, flags)
30+
err_name = cudart.cudaGetErrorName(err)[1].decode()
31+
err_str = cudart.cudaGetErrorString(err)[1].decode()
32+
33+
if err == 0:
34+
cudart.cudaGraphicsUnregisterResource(resource)
35+
raise AssertionError("Expected error from invalid GL texture ID")

0 commit comments

Comments
 (0)