From 0787d13f1ea522e21fc86607ae85bc310351e216 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 10 Aug 2023 07:29:50 -0500 Subject: [PATCH 1/5] Implement DPCTLPlatform_AreEq(PRef1, PRef1), DPCTLPlatform_Hash(PRef) Added declaration, doxygen docs, implementation and tests. --- .../include/dpctl_sycl_platform_interface.h | 23 ++++++++++++++++++ .../source/dpctl_sycl_platform_interface.cpp | 24 +++++++++++++++++++ .../tests/test_sycl_platform_interface.cpp | 19 +++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/libsyclinterface/include/dpctl_sycl_platform_interface.h b/libsyclinterface/include/dpctl_sycl_platform_interface.h index 3e40453e4a..90dab58f1c 100644 --- a/libsyclinterface/include/dpctl_sycl_platform_interface.h +++ b/libsyclinterface/include/dpctl_sycl_platform_interface.h @@ -39,6 +39,19 @@ DPCTL_C_EXTERN_C_BEGIN * @defgroup PlatformInterface Platform class C wrapper */ +/*! + * @brief Checks if two DPCTLSyclPlatformRef objects point to the same + * sycl::platform. + * + * @param PRef1 First opaque pointer to a ``sycl::platform``. + * @param PRef2 Second opaque pointer to a ``sycl::platform``. + * @return True if the underlying sycl::platform are same, false otherwise. + * @ingroup PlatformInterface + */ +DPCTL_API +bool DPCTLPlatform_AreEq(__dpctl_keep const DPCTLSyclPlatformRef PRef1, + __dpctl_keep const DPCTLSyclPlatformRef PRef2); + /*! * @brief Returns a copy of the DPCTLSyclPlatformRef object. * @@ -155,4 +168,14 @@ DPCTL_API __dpctl_give DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef); +/*! + * @brief Wrapper over std::hash's operator() + * + * @param PRef The DPCTLSyclPlatformRef pointer. + * @return Hash value of the underlying ``sycl::platform`` instance. + * @ingroup PlatformInterface + */ +DPCTL_API +size_t DPCTLPlatform_Hash(__dpctl_keep DPCTLSyclPlatformRef CtxRef); + DPCTL_C_EXTERN_C_END diff --git a/libsyclinterface/source/dpctl_sycl_platform_interface.cpp b/libsyclinterface/source/dpctl_sycl_platform_interface.cpp index 2083259611..5be98b7b61 100644 --- a/libsyclinterface/source/dpctl_sycl_platform_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_platform_interface.cpp @@ -234,3 +234,27 @@ DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef) return nullptr; } } + +bool DPCTLPlatform_AreEq(__dpctl_keep const DPCTLSyclPlatformRef PRef1, + __dpctl_keep const DPCTLSyclPlatformRef PRef2) +{ + auto P1 = unwrap(PRef1); + auto P2 = unwrap(PRef2); + if (P1 && P2) + return *P1 == *P2; + else + return false; +} + +size_t DPCTLPlatform_Hash(__dpctl_keep const DPCTLSyclPlatformRef PRef) +{ + if (PRef) { + auto P = unwrap(PRef); + std::hash hash_fn; + return hash_fn(*P); + } + else { + error_handler("Argument PRef is null.", __FILE__, __func__, __LINE__); + return 0; + } +} diff --git a/libsyclinterface/tests/test_sycl_platform_interface.cpp b/libsyclinterface/tests/test_sycl_platform_interface.cpp index e6fb7df134..f04cead0e1 100644 --- a/libsyclinterface/tests/test_sycl_platform_interface.cpp +++ b/libsyclinterface/tests/test_sycl_platform_interface.cpp @@ -264,6 +264,25 @@ TEST_P(TestDPCTLSyclPlatformInterface, ChkPrintInfoNullArg) EXPECT_NO_FATAL_FAILURE(DPCTLPlatformMgr_PrintInfo(Null_PRef, 0)); } +TEST_P(TestDPCTLSyclPlatformInterface, ChkAreEq) +{ + DPCTLSyclPlatformRef PRef_Copy = nullptr; + + EXPECT_NO_FATAL_FAILURE(PRef_Copy = DPCTLPlatform_Copy(PRef)); + + ASSERT_TRUE(DPCTLPlatform_AreEq(PRef, PRef_Copy)); + EXPECT_TRUE(DPCTLPlatform_Hash(PRef) == DPCTLPlatform_Hash(PRef_Copy)); + + EXPECT_NO_FATAL_FAILURE(DPCTLPlatform_Delete(PRef_Copy)); +} + +TEST_P(TestDPCTLSyclPlatformInterface, ChkAreEqNullArg) +{ + DPCTLSyclPlatformRef Null_PRef = nullptr; + ASSERT_FALSE(DPCTLPlatform_AreEq(PRef, Null_PRef)); + ASSERT_TRUE(DPCTLPlatform_Hash(Null_PRef) == 0); +} + TEST_F(TestDPCTLSyclDefaultPlatform, ChkGetName) { check_platform_name(PRef); From 358fc5d3d930c2195091d34e3471945c125f4f10 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 10 Aug 2023 08:51:48 -0500 Subject: [PATCH 2/5] Corrected declaration of DPCLTContext_Hash --- libsyclinterface/include/dpctl_sycl_context_interface.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libsyclinterface/include/dpctl_sycl_context_interface.h b/libsyclinterface/include/dpctl_sycl_context_interface.h index 6ca3c92d05..ddf80e1e67 100644 --- a/libsyclinterface/include/dpctl_sycl_context_interface.h +++ b/libsyclinterface/include/dpctl_sycl_context_interface.h @@ -159,6 +159,6 @@ void DPCTLContext_Delete(__dpctl_take DPCTLSyclContextRef CtxRef); * @ingroup ContextInterface */ DPCTL_API -size_t DPCTLContext_Hash(__dpctl_take DPCTLSyclContextRef CtxRef); +size_t DPCTLContext_Hash(__dpctl_keep DPCTLSyclContextRef CtxRef); DPCTL_C_EXTERN_C_END From 0bfc6f8d6f891a6b61e0d2148a49f530295aa622 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 10 Aug 2023 08:52:10 -0500 Subject: [PATCH 3/5] Implemented SyclPlatform.__eq__ and SyclPlatform.__hash__ --- dpctl/_backend.pxd | 2 ++ dpctl/_sycl_platform.pxd | 3 +++ dpctl/_sycl_platform.pyx | 40 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/dpctl/_backend.pxd b/dpctl/_backend.pxd index 630b80bb55..3f7ba63a55 100644 --- a/dpctl/_backend.pxd +++ b/dpctl/_backend.pxd @@ -299,6 +299,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_manager.h": cdef extern from "syclinterface/dpctl_sycl_platform_interface.h": + cdef bool DPCTLPlatform_AreEq(const DPCTLSyclPlatformRef, const DPCTLSyclPlatformRef) cdef DPCTLSyclPlatformRef DPCTLPlatform_Copy(const DPCTLSyclPlatformRef) cdef DPCTLSyclPlatformRef DPCTLPlatform_Create() cdef DPCTLSyclPlatformRef DPCTLPlatform_CreateFromSelector( @@ -308,6 +309,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_interface.h": cdef const char *DPCTLPlatform_GetName(const DPCTLSyclPlatformRef) cdef const char *DPCTLPlatform_GetVendor(const DPCTLSyclPlatformRef) cdef const char *DPCTLPlatform_GetVersion(const DPCTLSyclPlatformRef) + cdef size_t DPCTLPlatform_Hash(const DPCTLSyclPlatformRef) cdef DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms() cdef DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext( const DPCTLSyclPlatformRef) diff --git a/dpctl/_sycl_platform.pxd b/dpctl/_sycl_platform.pxd index d8d737d874..2f619203d6 100644 --- a/dpctl/_sycl_platform.pxd +++ b/dpctl/_sycl_platform.pxd @@ -21,6 +21,8 @@ SYCL platform-related helper functions. """ +from libcpp cimport bool + from ._backend cimport DPCTLSyclDeviceSelectorRef, DPCTLSyclPlatformRef @@ -40,6 +42,7 @@ cdef class SyclPlatform(_SyclPlatform): cdef int _init_from_selector(self, DPCTLSyclDeviceSelectorRef DSRef) cdef int _init_from__SyclPlatform(self, _SyclPlatform other) cdef DPCTLSyclPlatformRef get_platform_ref(self) + cdef bool equals(self, SyclPlatform) cpdef list get_platforms() diff --git a/dpctl/_sycl_platform.pyx b/dpctl/_sycl_platform.pyx index 449ad9029f..589c8ae895 100644 --- a/dpctl/_sycl_platform.pyx +++ b/dpctl/_sycl_platform.pyx @@ -21,10 +21,13 @@ """ Implements SyclPlatform Cython extension type. """ +from libcpp cimport bool + from ._backend cimport ( # noqa: E211 DPCTLCString_Delete, DPCTLDeviceSelector_Delete, DPCTLFilterSelector_Create, + DPCTLPlatform_AreEq, DPCTLPlatform_Copy, DPCTLPlatform_Create, DPCTLPlatform_CreateFromSelector, @@ -35,6 +38,7 @@ from ._backend cimport ( # noqa: E211 DPCTLPlatform_GetPlatforms, DPCTLPlatform_GetVendor, DPCTLPlatform_GetVersion, + DPCTLPlatform_Hash, DPCTLPlatformMgr_GetInfo, DPCTLPlatformMgr_PrintInfo, DPCTLPlatformVector_Delete, @@ -274,6 +278,42 @@ cdef class SyclPlatform(_SyclPlatform): else: return SyclContext._create(CRef) + cdef bool equals(self, SyclPlatform other): + """ + Returns true if the :class:`dpctl.SyclPlatform` argument has the + same underlying ``DPCTLSyclPlatformRef`` object as this + :class:`dpctl.SyclPlatform` instance. + + Returns: + :obj:`bool`: ``True`` if the two :class:`dpctl.SyclPlatform` objects + point to the same ``DPCTLSyclPlatformRef`` object, otherwise + ``False``. + """ + return DPCTLPlatform_AreEq(self._platform_ref, other.get_platform_ref()) + + def __eq__(self, other): + """ + Returns True if the :class:`dpctl.SyclPlatform` argument has the + same underlying ``DPCTLSyclPlatformRef`` object as this + :class:`dpctl.SyclPlatform` instance. + + Returns: + :obj:`bool`: ``True`` if the two :class:`dpctl.SyclPlatform` objects + point to the same ``DPCTLSyclPlatformRef`` object, otherwise + ``False``. + """ + if isinstance(other, SyclPlatform): + return self.equals( other) + else: + return False + + def __hash__(self): + """ + Returns a hash value by hashing the underlying ``sycl::platform`` object. + + """ + return DPCTLPlatform_Hash(self._platform_ref) + def lsplatform(verbosity=0): """ From b40d357d40628b5ff6e7509cd1cfcc9cc2068e47 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 10 Aug 2023 09:07:36 -0500 Subject: [PATCH 4/5] Added tests for equality testing and hashing of SyclPlatform Also enabled overlooked check_default_context --- dpctl/tests/test_sycl_platform.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/dpctl/tests/test_sycl_platform.py b/dpctl/tests/test_sycl_platform.py index 39331f9cdc..6b61160f9d 100644 --- a/dpctl/tests/test_sycl_platform.py +++ b/dpctl/tests/test_sycl_platform.py @@ -92,6 +92,19 @@ def check_default_context(platform): assert type(r) is dpctl.SyclContext +def check_equal_and_hash(platform): + assert platform == platform + default_ctx = platform.default_context + for d in default_ctx.get_devices(): + assert platform == d.sycl_platform + assert hash(platform) == hash(d.sycl_platform) + + +def check_hash_in_dict(platform): + map = {platform: 0} + assert map[platform] == 0 + + list_of_checks = [ check_name, check_vendor, @@ -99,6 +112,9 @@ def check_default_context(platform): check_backend, check_print_info, check_repr, + check_default_context, + check_equal_and_hash, + check_hash_in_dict, ] From e54aaa0d44073b760dd6244af3124f077e4ce982 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 10 Aug 2023 14:39:44 -0500 Subject: [PATCH 5/5] Exercise default_sycl_platform on Linux only It is not yet supported on Windows. --- dpctl/tests/test_sycl_platform.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dpctl/tests/test_sycl_platform.py b/dpctl/tests/test_sycl_platform.py index 6b61160f9d..fa93dbbd12 100644 --- a/dpctl/tests/test_sycl_platform.py +++ b/dpctl/tests/test_sycl_platform.py @@ -17,6 +17,8 @@ """Defines unit test cases for the SyclPlatform class. """ +import sys + import pytest from helper import has_sycl_platforms @@ -88,12 +90,16 @@ def check_repr(platform): def check_default_context(platform): + if "linux" not in sys.platform: + return r = platform.default_context assert type(r) is dpctl.SyclContext def check_equal_and_hash(platform): assert platform == platform + if "linux" not in sys.platform: + return default_ctx = platform.default_context for d in default_ctx.get_devices(): assert platform == d.sycl_platform