Commit 710ff57

Merge pull request #10709 from MamziB/mamzi/single-thread-enhancements-3
OSC/UCX: Adding the following optimizations (nonblocking accumulate and reusing resources)
2 parents: d13b25d + 1ea6fb9

File tree

9 files changed: +1073 −270 lines

ompi/mca/osc/ucx/osc_ucx.h

Lines changed: 42 additions & 8 deletions
@@ -27,13 +27,16 @@
 #define OMPI_OSC_UCX_ATTACH_MAX 48
 #define OMPI_OSC_UCX_MEM_ADDR_MAX_LEN 1024
 
+
 typedef struct ompi_osc_ucx_component {
     ompi_osc_base_component_t super;
     opal_common_ucx_wpool_t *wpool;
     bool enable_mpi_threads;
     opal_free_list_t requests; /* request free list for the r* communication variants */
+    opal_free_list_t accumulate_requests; /* request free list for the r* communication variants */
     bool env_initialized; /* UCX environment is initialized or not */
-    int num_incomplete_req_ops;
+    int comm_world_size;
+    ucp_ep_h *endpoints;
     int num_modules;
     bool no_locks; /* Default value of the no_locks info key for new windows */
     bool acc_single_intrinsic;
@@ -44,6 +47,16 @@ typedef struct ompi_osc_ucx_component {
 
 OMPI_DECLSPEC extern ompi_osc_ucx_component_t mca_osc_ucx_component;
 
+#define OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(_module) \
+    do { \
+        opal_atomic_add_fetch_size_t(&_module->ctx->num_incomplete_req_ops, 1); \
+    } while(0);
+
+#define OSC_UCX_DECREMENT_OUTSTANDING_NB_OPS(_module) \
+    do { \
+        opal_atomic_add_fetch_size_t(&_module->ctx->num_incomplete_req_ops, -1); \
+    } while(0);
+
 typedef enum ompi_osc_ucx_epoch {
     NONE_EPOCH,
     FENCE_EPOCH,
@@ -69,7 +82,8 @@ typedef struct ompi_osc_ucx_epoch_type {
 #define OSC_UCX_STATE_COMPLETE_COUNT_OFFSET (sizeof(uint64_t) * 3)
 #define OSC_UCX_STATE_POST_INDEX_OFFSET (sizeof(uint64_t) * 4)
 #define OSC_UCX_STATE_POST_STATE_OFFSET (sizeof(uint64_t) * 5)
-#define OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET (sizeof(uint64_t) * (5 + OMPI_OSC_UCX_POST_PEER_MAX))
+#define OSC_UCX_STATE_DYNAMIC_LOCK_OFFSET (sizeof(uint64_t) * 6)
+#define OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET (sizeof(uint64_t) * (6 + OMPI_OSC_UCX_POST_PEER_MAX))
 
 typedef struct ompi_osc_dynamic_win_info {
     uint64_t base;
@@ -102,6 +116,7 @@ typedef struct ompi_osc_ucx_module {
     size_t size;
     uint64_t *addrs;
     uint64_t *state_addrs;
+    uint64_t *comm_world_ranks;
     int disp_unit; /* if disp_unit >= 0, then everyone has the same
                     * disp unit size; if disp_unit == -1, then we
                     * need to look at disp_units */
@@ -125,6 +140,7 @@ typedef struct ompi_osc_ucx_module {
     opal_common_ucx_wpmem_t *mem;
     opal_common_ucx_wpmem_t *state_mem;
 
+    bool skip_sync_check;
     bool noncontig_shared_win;
     size_t *sizes;
     /* in shared windows, shmem_addrs can be used for direct load store to
@@ -147,9 +163,18 @@ typedef struct ompi_osc_ucx_lock {
     bool is_nocheck;
 } ompi_osc_ucx_lock_t;
 
-#define OSC_UCX_GET_EP(comm_, rank_) (ompi_comm_peer_lookup(comm_, rank_)->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_UCX])
+#define OSC_UCX_GET_EP(_module, rank_) (mca_osc_ucx_component.endpoints[_module->comm_world_ranks[rank_]])
 #define OSC_UCX_GET_DISP(module_, rank_) ((module_->disp_unit < 0) ? module_->disp_units[rank_] : module_->disp_unit)
 
+#define OSC_UCX_GET_DEFAULT_EP(_ep_ptr, _module, _target) \
+    if (opal_common_ucx_thread_enabled) { \
+        _ep_ptr = NULL; \
+    } else { \
+        _ep_ptr = (ucp_ep_h *)&(OSC_UCX_GET_EP(_module, _target)); \
+    }
+
+extern size_t ompi_osc_ucx_outstanding_ops_flush_threshold;
+
 int ompi_osc_ucx_shared_query(struct ompi_win_t *win, int rank, size_t *size,
                               int *disp_unit, void * baseptr);
 int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len);
@@ -169,6 +194,11 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
                             int target, ptrdiff_t target_disp, int target_count,
                             struct ompi_datatype_t *target_dt,
                             struct ompi_op_t *op, struct ompi_win_t *win);
+int ompi_osc_ucx_accumulate_nb(const void *origin_addr, int origin_count,
+                            struct ompi_datatype_t *origin_dt,
+                            int target, ptrdiff_t target_disp, int target_count,
+                            struct ompi_datatype_t *target_dt,
+                            struct ompi_op_t *op, struct ompi_win_t *win);
 int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr,
                                   void *result_addr, struct ompi_datatype_t *dt,
                                   int target, ptrdiff_t target_disp,
@@ -184,6 +214,13 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
                                 int target_rank, ptrdiff_t target_disp,
                                 int target_count, struct ompi_datatype_t *target_datatype,
                                 struct ompi_op_t *op, struct ompi_win_t *win);
+int ompi_osc_ucx_get_accumulate_nb(const void *origin_addr, int origin_count,
+                            struct ompi_datatype_t *origin_datatype,
+                            void *result_addr, int result_count,
+                            struct ompi_datatype_t *result_datatype,
+                            int target_rank, ptrdiff_t target_disp,
+                            int target_count, struct ompi_datatype_t *target_datatype,
+                            struct ompi_op_t *op, struct ompi_win_t *win);
 int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
                       struct ompi_datatype_t *origin_dt,
                       int target, ptrdiff_t target_disp, int target_count,
@@ -228,10 +265,7 @@ int ompi_osc_ucx_flush_local_all(struct ompi_win_t *win);
 int ompi_osc_find_attached_region_position(ompi_osc_dynamic_win_info_t *dynamic_wins,
                                            int min_index, int max_index,
                                            uint64_t base, size_t len, int *insert);
-extern inline bool ompi_osc_need_acc_lock(ompi_osc_ucx_module_t *module, int target);
-extern inline int ompi_osc_state_lock(ompi_osc_ucx_module_t *module, int target,
-                        bool *lock_acquired, bool force_lock);
-extern inline int ompi_osc_state_unlock(ompi_osc_ucx_module_t *module, int target,
-                        bool lock_acquired, void *free_ptr);
+int ompi_osc_ucx_dynamic_lock(ompi_osc_ucx_module_t *module, int target);
+int ompi_osc_ucx_dynamic_unlock(ompi_osc_ucx_module_t *module, int target);
 
 #endif /* OMPI_OSC_UCX_H */
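
Taken together, the new OSC_UCX_INCREMENT/DECREMENT_OUTSTANDING_NB_OPS macros and the ompi_osc_ucx_outstanding_ops_flush_threshold extern give the nonblocking accumulate path a way to bound how many operations are left in flight. A minimal sketch of how these pieces could be combined, assuming a throttle helper is called before each nonblocking accumulate is posted (the helper name osc_ucx_nb_ops_throttle and the direct read of num_incomplete_req_ops are illustrative; only the macros, the threshold variable, and opal_common_ucx_ctx_flush() come from this diff):

/* Illustrative sketch, not code from this commit: throttle nonblocking
 * accumulates by flushing the worker once too many are outstanding. */
static inline int osc_ucx_nb_ops_throttle(ompi_osc_ucx_module_t *module) {
    int ret = OMPI_SUCCESS;

    /* Heuristic check; a plain (non-atomic) read of the counter is enough here. */
    if (module->ctx->num_incomplete_req_ops >=
            ompi_osc_ucx_outstanding_ops_flush_threshold) {
        ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_WORKER,
                                        0 /*ignore*/);
        if (ret != OMPI_SUCCESS) {
            return ret;
        }
    }

    /* Account for the operation about to be posted; the matching
     * OSC_UCX_DECREMENT_OUTSTANDING_NB_OPS(module) would run when the
     * request completes. */
    OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(module);
    return ret;
}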

ompi/mca/osc/ucx/osc_ucx_active_target.c

Lines changed: 11 additions & 8 deletions
@@ -165,31 +165,33 @@ int ompi_osc_ucx_complete(struct ompi_win_t *win) {
     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
     int i, size;
     int ret = OMPI_SUCCESS;
+    ucp_ep_h *ep;
 
     if (module->epoch_type.access != START_COMPLETE_EPOCH) {
         return OMPI_ERR_RMA_SYNC;
     }
 
-    module->epoch_type.access = NONE_EPOCH;
-
     ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_WORKER, 0/*ignore*/);
     if (ret != OMPI_SUCCESS) {
        return ret;
     }
 
+    module->epoch_type.access = NONE_EPOCH;
+
     size = ompi_group_size(module->start_group);
     for (i = 0; i < size; i++) {
         uint64_t remote_addr = module->state_addrs[module->start_grp_ranks[i]] + OSC_UCX_STATE_COMPLETE_COUNT_OFFSET; // write to state.complete_count on remote side
 
+        OSC_UCX_GET_DEFAULT_EP(ep, module, module->start_grp_ranks[i]);
+
         ret = opal_common_ucx_wpmem_post(module->state_mem, UCP_ATOMIC_POST_OP_ADD,
                                          1, module->start_grp_ranks[i], sizeof(uint64_t),
-                                         remote_addr);
+                                         remote_addr, ep);
         if (ret != OMPI_SUCCESS) {
             OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_post failed: %d", ret);
         }
 
-        ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP,
-                                        module->start_grp_ranks[i]);
+        ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, module->start_grp_ranks[i]);
         if (ret != OMPI_SUCCESS) {
             return ret;
         }
@@ -204,6 +206,7 @@ int ompi_osc_ucx_complete(struct ompi_win_t *win) {
 
 int ompi_osc_ucx_post(struct ompi_group_t *group, int mpi_assert, struct ompi_win_t *win) {
     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
+    ucp_ep_h *ep;
     int ret = OMPI_SUCCESS;
 
     if (module->epoch_type.exposure != NONE_EPOCH) {
@@ -243,12 +246,12 @@ int ompi_osc_ucx_post(struct ompi_group_t *group, int mpi_assert, struct ompi_wi
             uint64_t remote_addr = module->state_addrs[ranks_in_win_grp[i]] + OSC_UCX_STATE_POST_INDEX_OFFSET; // write to state.post_index on remote side
             uint64_t curr_idx = 0, result = 0;
 
-
+            OSC_UCX_GET_DEFAULT_EP(ep, module, ranks_in_win_grp[i]);
 
             /* do fop first to get an post index */
             ret = opal_common_ucx_wpmem_fetch(module->state_mem, UCP_ATOMIC_FETCH_OP_FADD,
                                               1, ranks_in_win_grp[i], &result,
-                                              sizeof(result), remote_addr);
+                                              sizeof(result), remote_addr, ep);
 
             if (ret != OMPI_SUCCESS) {
                 ret = OMPI_ERROR;
@@ -265,7 +268,7 @@ int ompi_osc_ucx_post(struct ompi_group_t *group, int mpi_assert, struct ompi_wi
             result = myrank + 1;
             ret = opal_common_ucx_wpmem_cmpswp(module->state_mem, 0, result,
                                                ranks_in_win_grp[i], &result, sizeof(result),
-                                               remote_addr);
+                                               remote_addr, ep);
 
             if (ret != OMPI_SUCCESS) {
                 ret = OMPI_ERROR;
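
The changes in this file all follow the same endpoint-reuse pattern: resolve the target's endpoint once with OSC_UCX_GET_DEFAULT_EP() and pass it to the opal_common_ucx_wpmem_* call, so single-threaded runs reuse the endpoints cached in mca_osc_ucx_component.endpoints instead of going through the worker pool (with MPI_THREAD_MULTIPLE the macro yields NULL and the previous path is kept). A minimal sketch of that pattern, using a hypothetical helper that atomically adds 1 to a uint64_t in the target's state window; the call shape of opal_common_ucx_wpmem_post() matches its use in ompi_osc_ucx_complete() above:

/* Illustrative sketch, not code from this commit: the endpoint-reuse pattern. */
static int osc_ucx_state_add_one(ompi_osc_ucx_module_t *module, int target,
                                 uint64_t remote_addr) {
    ucp_ep_h *ep;

    /* NULL in threaded runs, cached per-target endpoint otherwise. */
    OSC_UCX_GET_DEFAULT_EP(ep, module, target);

    /* Atomically add 1 to the uint64_t at remote_addr in the target's state
     * memory, using the resolved endpoint when one is available. */
    return opal_common_ucx_wpmem_post(module->state_mem, UCP_ATOMIC_POST_OP_ADD,
                                      1, target, sizeof(uint64_t), remote_addr, ep);
}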
