diff --git a/dpctl/_backend.pxd b/dpctl/_backend.pxd index 630b80bb55..3f7ba63a55 100644 --- a/dpctl/_backend.pxd +++ b/dpctl/_backend.pxd @@ -299,6 +299,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_manager.h": cdef extern from "syclinterface/dpctl_sycl_platform_interface.h": + cdef bool DPCTLPlatform_AreEq(const DPCTLSyclPlatformRef, const DPCTLSyclPlatformRef) cdef DPCTLSyclPlatformRef DPCTLPlatform_Copy(const DPCTLSyclPlatformRef) cdef DPCTLSyclPlatformRef DPCTLPlatform_Create() cdef DPCTLSyclPlatformRef DPCTLPlatform_CreateFromSelector( @@ -308,6 +309,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_interface.h": cdef const char *DPCTLPlatform_GetName(const DPCTLSyclPlatformRef) cdef const char *DPCTLPlatform_GetVendor(const DPCTLSyclPlatformRef) cdef const char *DPCTLPlatform_GetVersion(const DPCTLSyclPlatformRef) + cdef size_t DPCTLPlatform_Hash(const DPCTLSyclPlatformRef) cdef DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms() cdef DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext( const DPCTLSyclPlatformRef) diff --git a/dpctl/_sycl_platform.pxd b/dpctl/_sycl_platform.pxd index d8d737d874..2f619203d6 100644 --- a/dpctl/_sycl_platform.pxd +++ b/dpctl/_sycl_platform.pxd @@ -21,6 +21,8 @@ SYCL platform-related helper functions. """ +from libcpp cimport bool + from ._backend cimport DPCTLSyclDeviceSelectorRef, DPCTLSyclPlatformRef @@ -40,6 +42,7 @@ cdef class SyclPlatform(_SyclPlatform): cdef int _init_from_selector(self, DPCTLSyclDeviceSelectorRef DSRef) cdef int _init_from__SyclPlatform(self, _SyclPlatform other) cdef DPCTLSyclPlatformRef get_platform_ref(self) + cdef bool equals(self, SyclPlatform) cpdef list get_platforms() diff --git a/dpctl/_sycl_platform.pyx b/dpctl/_sycl_platform.pyx index 449ad9029f..589c8ae895 100644 --- a/dpctl/_sycl_platform.pyx +++ b/dpctl/_sycl_platform.pyx @@ -21,10 +21,13 @@ """ Implements SyclPlatform Cython extension type. """ +from libcpp cimport bool + from ._backend cimport ( # noqa: E211 DPCTLCString_Delete, DPCTLDeviceSelector_Delete, DPCTLFilterSelector_Create, + DPCTLPlatform_AreEq, DPCTLPlatform_Copy, DPCTLPlatform_Create, DPCTLPlatform_CreateFromSelector, @@ -35,6 +38,7 @@ from ._backend cimport ( # noqa: E211 DPCTLPlatform_GetPlatforms, DPCTLPlatform_GetVendor, DPCTLPlatform_GetVersion, + DPCTLPlatform_Hash, DPCTLPlatformMgr_GetInfo, DPCTLPlatformMgr_PrintInfo, DPCTLPlatformVector_Delete, @@ -274,6 +278,42 @@ cdef class SyclPlatform(_SyclPlatform): else: return SyclContext._create(CRef) + cdef bool equals(self, SyclPlatform other): + """ + Returns true if the :class:`dpctl.SyclPlatform` argument has the + same underlying ``DPCTLSyclPlatformRef`` object as this + :class:`dpctl.SyclPlatform` instance. + + Returns: + :obj:`bool`: ``True`` if the two :class:`dpctl.SyclPlatform` objects + point to the same ``DPCTLSyclPlatformRef`` object, otherwise + ``False``. + """ + return DPCTLPlatform_AreEq(self._platform_ref, other.get_platform_ref()) + + def __eq__(self, other): + """ + Returns True if the :class:`dpctl.SyclPlatform` argument has the + same underlying ``DPCTLSyclPlatformRef`` object as this + :class:`dpctl.SyclPlatform` instance. + + Returns: + :obj:`bool`: ``True`` if the two :class:`dpctl.SyclPlatform` objects + point to the same ``DPCTLSyclPlatformRef`` object, otherwise + ``False``. + """ + if isinstance(other, SyclPlatform): + return self.equals( other) + else: + return False + + def __hash__(self): + """ + Returns a hash value by hashing the underlying ``sycl::platform`` object. + + """ + return DPCTLPlatform_Hash(self._platform_ref) + def lsplatform(verbosity=0): """ diff --git a/dpctl/tests/test_sycl_platform.py b/dpctl/tests/test_sycl_platform.py index 39331f9cdc..fa93dbbd12 100644 --- a/dpctl/tests/test_sycl_platform.py +++ b/dpctl/tests/test_sycl_platform.py @@ -17,6 +17,8 @@ """Defines unit test cases for the SyclPlatform class. """ +import sys + import pytest from helper import has_sycl_platforms @@ -88,10 +90,27 @@ def check_repr(platform): def check_default_context(platform): + if "linux" not in sys.platform: + return r = platform.default_context assert type(r) is dpctl.SyclContext +def check_equal_and_hash(platform): + assert platform == platform + if "linux" not in sys.platform: + return + default_ctx = platform.default_context + for d in default_ctx.get_devices(): + assert platform == d.sycl_platform + assert hash(platform) == hash(d.sycl_platform) + + +def check_hash_in_dict(platform): + map = {platform: 0} + assert map[platform] == 0 + + list_of_checks = [ check_name, check_vendor, @@ -99,6 +118,9 @@ def check_default_context(platform): check_backend, check_print_info, check_repr, + check_default_context, + check_equal_and_hash, + check_hash_in_dict, ] diff --git a/libsyclinterface/include/dpctl_sycl_context_interface.h b/libsyclinterface/include/dpctl_sycl_context_interface.h index 6ca3c92d05..ddf80e1e67 100644 --- a/libsyclinterface/include/dpctl_sycl_context_interface.h +++ b/libsyclinterface/include/dpctl_sycl_context_interface.h @@ -159,6 +159,6 @@ void DPCTLContext_Delete(__dpctl_take DPCTLSyclContextRef CtxRef); * @ingroup ContextInterface */ DPCTL_API -size_t DPCTLContext_Hash(__dpctl_take DPCTLSyclContextRef CtxRef); +size_t DPCTLContext_Hash(__dpctl_keep DPCTLSyclContextRef CtxRef); DPCTL_C_EXTERN_C_END diff --git a/libsyclinterface/include/dpctl_sycl_platform_interface.h b/libsyclinterface/include/dpctl_sycl_platform_interface.h index 3e40453e4a..90dab58f1c 100644 --- a/libsyclinterface/include/dpctl_sycl_platform_interface.h +++ b/libsyclinterface/include/dpctl_sycl_platform_interface.h @@ -39,6 +39,19 @@ DPCTL_C_EXTERN_C_BEGIN * @defgroup PlatformInterface Platform class C wrapper */ +/*! + * @brief Checks if two DPCTLSyclPlatformRef objects point to the same + * sycl::platform. + * + * @param PRef1 First opaque pointer to a ``sycl::platform``. + * @param PRef2 Second opaque pointer to a ``sycl::platform``. + * @return True if the underlying sycl::platform are same, false otherwise. + * @ingroup PlatformInterface + */ +DPCTL_API +bool DPCTLPlatform_AreEq(__dpctl_keep const DPCTLSyclPlatformRef PRef1, + __dpctl_keep const DPCTLSyclPlatformRef PRef2); + /*! * @brief Returns a copy of the DPCTLSyclPlatformRef object. * @@ -155,4 +168,14 @@ DPCTL_API __dpctl_give DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef); +/*! + * @brief Wrapper over std::hash's operator() + * + * @param PRef The DPCTLSyclPlatformRef pointer. + * @return Hash value of the underlying ``sycl::platform`` instance. + * @ingroup PlatformInterface + */ +DPCTL_API +size_t DPCTLPlatform_Hash(__dpctl_keep DPCTLSyclPlatformRef CtxRef); + DPCTL_C_EXTERN_C_END diff --git a/libsyclinterface/source/dpctl_sycl_platform_interface.cpp b/libsyclinterface/source/dpctl_sycl_platform_interface.cpp index 2083259611..5be98b7b61 100644 --- a/libsyclinterface/source/dpctl_sycl_platform_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_platform_interface.cpp @@ -234,3 +234,27 @@ DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef) return nullptr; } } + +bool DPCTLPlatform_AreEq(__dpctl_keep const DPCTLSyclPlatformRef PRef1, + __dpctl_keep const DPCTLSyclPlatformRef PRef2) +{ + auto P1 = unwrap(PRef1); + auto P2 = unwrap(PRef2); + if (P1 && P2) + return *P1 == *P2; + else + return false; +} + +size_t DPCTLPlatform_Hash(__dpctl_keep const DPCTLSyclPlatformRef PRef) +{ + if (PRef) { + auto P = unwrap(PRef); + std::hash hash_fn; + return hash_fn(*P); + } + else { + error_handler("Argument PRef is null.", __FILE__, __func__, __LINE__); + return 0; + } +} diff --git a/libsyclinterface/tests/test_sycl_platform_interface.cpp b/libsyclinterface/tests/test_sycl_platform_interface.cpp index e6fb7df134..f04cead0e1 100644 --- a/libsyclinterface/tests/test_sycl_platform_interface.cpp +++ b/libsyclinterface/tests/test_sycl_platform_interface.cpp @@ -264,6 +264,25 @@ TEST_P(TestDPCTLSyclPlatformInterface, ChkPrintInfoNullArg) EXPECT_NO_FATAL_FAILURE(DPCTLPlatformMgr_PrintInfo(Null_PRef, 0)); } +TEST_P(TestDPCTLSyclPlatformInterface, ChkAreEq) +{ + DPCTLSyclPlatformRef PRef_Copy = nullptr; + + EXPECT_NO_FATAL_FAILURE(PRef_Copy = DPCTLPlatform_Copy(PRef)); + + ASSERT_TRUE(DPCTLPlatform_AreEq(PRef, PRef_Copy)); + EXPECT_TRUE(DPCTLPlatform_Hash(PRef) == DPCTLPlatform_Hash(PRef_Copy)); + + EXPECT_NO_FATAL_FAILURE(DPCTLPlatform_Delete(PRef_Copy)); +} + +TEST_P(TestDPCTLSyclPlatformInterface, ChkAreEqNullArg) +{ + DPCTLSyclPlatformRef Null_PRef = nullptr; + ASSERT_FALSE(DPCTLPlatform_AreEq(PRef, Null_PRef)); + ASSERT_TRUE(DPCTLPlatform_Hash(Null_PRef) == 0); +} + TEST_F(TestDPCTLSyclDefaultPlatform, ChkGetName) { check_platform_name(PRef);