Skip to content

Commit 95047e4

Browse files
committed
Fixes equality and hash to make these consistent. Both include context handle.
1 parent 8c9900d commit 95047e4

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

cuda_core/cuda/core/experimental/_stream.pyx

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,25 @@ cdef class Stream:
198198
return (0, <uintptr_t>(self._handle))
199199

200200
def __hash__(self) -> int:
201+
# Ensure context is initialized for hash consistency
202+
if self._ctx_handle == CU_CONTEXT_INVALID:
203+
self._get_context()
201204
return hash((type(self), <uintptr_t>(self._ctx_handle), <uintptr_t>(self._handle)))
202205

203206
def __eq__(self, other) -> bool:
204207
if not isinstance(other, Stream):
205208
return NotImplemented
206209
cdef Stream _other = <Stream>other
207-
return <uintptr_t>(self._handle) == <uintptr_t>((_other)._handle)
210+
# Fast path: compare handles first
211+
if <uintptr_t>(self._handle) != <uintptr_t>((_other)._handle):
212+
return False
213+
# Ensure contexts are initialized for both streams
214+
if self._ctx_handle == CU_CONTEXT_INVALID:
215+
self._get_context()
216+
if _other._ctx_handle == CU_CONTEXT_INVALID:
217+
_other._get_context()
218+
# Compare contexts as well
219+
return <uintptr_t>(self._ctx_handle) == <uintptr_t>((_other)._ctx_handle)
208220

209221
@property
210222
def handle(self) -> cuda.bindings.driver.CUstream:

0 commit comments

Comments
 (0)