2424 return OMPI_ERROR; \
2525 }
2626
/*
 * Check whether UCX supports atomic operations on operands of the given size.
 *
 * Evaluates to non-zero when _size equals sizeof(uint32_t) or
 * sizeof(uint64_t), and to zero otherwise.
 *
 * Note: _size is evaluated twice, so do not pass an expression with side
 * effects. The argument is parenthesized in the expansion so that compound
 * expressions (e.g. a conditional expression) compare as a whole.
 */
#define ATOMIC_SIZE_SUPPORTED(_size) (sizeof(uint32_t) == (_size) || \
                                      sizeof(uint64_t) == (_size))
30+
2731typedef struct ucx_iovec {
2832 void * addr ;
2933 size_t len ;
@@ -384,9 +388,8 @@ bool use_atomic_op(
384388 ompi_datatype_type_size (origin_dt , & origin_dt_bytes );
385389 ompi_datatype_type_size (target_dt , & target_dt_bytes );
386390 /* UCX only supports 32 and 64-bit operands atm */
387- if (sizeof (uint64_t ) >= origin_dt_bytes &&
388- sizeof (uint32_t ) <= origin_dt_bytes &&
389- origin_dt_bytes == target_dt_bytes &&
391+ if (ATOMIC_SIZE_SUPPORTED (origin_dt_bytes ) &&
392+ origin_dt_bytes == target_dt_bytes &&
390393 origin_count == target_count ) {
391394 return true;
392395 }
@@ -775,7 +778,7 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
775778 }
776779
777780 ompi_datatype_type_size (dt , & dt_bytes );
778- if (4 == dt_bytes || 8 == dt_bytes ) {
781+ if (ATOMIC_SIZE_SUPPORTED ( dt_bytes ) ) {
779782 // fast path using UCX atomic operations
780783 return do_atomic_compare_and_swap (origin_addr , compare_addr ,
781784 result_addr , dt , target ,
@@ -818,6 +821,7 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
818821 struct ompi_datatype_t * dt , int target ,
819822 ptrdiff_t target_disp , struct ompi_op_t * op ,
820823 struct ompi_win_t * win ) {
824+ size_t dt_bytes ;
821825 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
822826 int ret = OMPI_SUCCESS ;
823827
@@ -826,12 +830,15 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
826830 return ret ;
827831 }
828832
829- if (op == & ompi_mpi_op_no_op .op || op == & ompi_mpi_op_replace .op ||
830- op == & ompi_mpi_op_sum .op ) {
833+ ompi_datatype_type_size (dt , & dt_bytes );
834+
835+ /* UCX atomics are only supported on 32 and 64 bit values */
836+ if (ATOMIC_SIZE_SUPPORTED (dt_bytes ) &&
837+ (op == & ompi_mpi_op_no_op .op || op == & ompi_mpi_op_replace .op ||
838+ op == & ompi_mpi_op_sum .op )) {
831839 uint64_t value ;
832840 uint64_t remote_addr = (module -> addrs [target ]) + target_disp * OSC_UCX_GET_DISP (module , target );
833841 ucp_atomic_fetch_op_t opcode ;
834- size_t dt_bytes ;
835842 bool lock_acquired = false;
836843
837844 if (!module -> acc_single_intrinsic ) {
0 commit comments