Skip to content

Commit 66bd7d1

Browse files
committed
UCX osc: add macro to check for AMO support of data-type size
Signed-off-by: Joseph Schuchart <[email protected]>
1 parent b6a2f4b commit 66bd7d1

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
return OMPI_ERROR; \
2525
}
2626

27+
/* macro to check whether UCX supports atomic operation on the size the operands */
28+
#define ATOMIC_SIZE_SUPPORTED(_size) (sizeof(uint32_t) == _size || \
29+
sizeof(uint64_t) == _size)
30+
2731
typedef struct ucx_iovec {
2832
void *addr;
2933
size_t len;
@@ -384,9 +388,8 @@ bool use_atomic_op(
384388
ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
385389
ompi_datatype_type_size(target_dt, &target_dt_bytes);
386390
/* UCX only supports 32 and 64-bit operands atm */
387-
if (sizeof(uint64_t) >= origin_dt_bytes &&
388-
sizeof(uint32_t) <= origin_dt_bytes &&
389-
origin_dt_bytes == target_dt_bytes &&
391+
if (ATOMIC_SIZE_SUPPORTED(origin_dt_bytes) &&
392+
origin_dt_bytes == target_dt_bytes &&
390393
origin_count == target_count) {
391394
return true;
392395
}
@@ -775,7 +778,7 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
775778
}
776779

777780
ompi_datatype_type_size(dt, &dt_bytes);
778-
if (4 == dt_bytes || 8 == dt_bytes) {
781+
if (ATOMIC_SIZE_SUPPORTED(dt_bytes)) {
779782
// fast path using UCX atomic operations
780783
return do_atomic_compare_and_swap(origin_addr, compare_addr,
781784
result_addr, dt, target,
@@ -830,7 +833,7 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
830833
ompi_datatype_type_size(dt, &dt_bytes);
831834

832835
/* UCX atomics are only supported on 32 and 64 bit values */
833-
if ((64 == dt_bytes || 32 == dt_bytes) &&
836+
if (ATOMIC_SIZE_SUPPORTED(dt_bytes) &&
834837
(op == &ompi_mpi_op_no_op.op || op == &ompi_mpi_op_replace.op ||
835838
op == &ompi_mpi_op_sum.op)) {
836839
uint64_t value;

0 commit comments

Comments
 (0)