Skip to content

Commit 4fafbfe

Browse files
committed
coll/han/alltoallv: Fix for when types have negative LB
Previously this function considered only displ, count, and extent. Since the function uses XPMEM to explicitly expose memory regions, we must also be aware of types that have negative lower bounds and might access data _before_ the user-provided pointer. This change more accurately computes the true upper and lower bounds of all memory accesses, both to ensure we don't try to map regions of memory that may not be in our VM page table, and to ensure we expose all memory that will be accessed. Signed-off-by: Luke Robison <[email protected]>
1 parent a50519c commit 4fafbfe

File tree

1 file changed

+125
-34
lines changed

1 file changed

+125
-34
lines changed

ompi/mca/coll/han/coll_han_alltoallv.c

Lines changed: 125 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,74 @@ struct gathered_data {
7878
size_t rtype_serialized_length;
7979
size_t sbuf_length;
8080
size_t rbuf_length;
81+
ssize_t sbuf_lb;
82+
ssize_t rbuf_lb;
8183
void *sbuf;
8284
void *rbuf;
8385
void *serialization_buffer;
8486
};
8587

88+
/**
89+
given a count, displ, and type, compute the true lb and ub for all data used by the count and displ arguments.
90+
Note the byte at UB is not accessed. so full_true_extent is UB-LB.
91+
92+
Consider the most difficult case: a resized type with non-zero LB and
93+
extent != true_extent, and true_LB != LB. In the figure below:
94+
X represents the 0-point of the user's `buf`
95+
x represents data accessed by the type
96+
- represents data spanned by the type when count > 1
97+
. represents data not accessed or spanned by the type.
98+
99+
+ LB = -5
100+
| + true_LB = -2
101+
| | + buf (0)
102+
| | |
103+
...---xxXxxx----...
104+
| |<-->| true_extent = 6
105+
|<--------->| extent = 13
106+
107+
When there are 2 items, the full true extent is
108+
...---xxXxxx-------xxxxxx----...
109+
| |<--------------->| true_extent = 19 ie: extent*(n-1) + true_extent
110+
|<---------------------->| extent = 26 ie: extent*n
111+
112+
*/
113+
static int han_alltoallv_dtype_get_many_true_lb_ub(
114+
ompi_datatype_t *dtype,
115+
ptrdiff_t count,
116+
ptrdiff_t displ,
117+
ptrdiff_t *full_true_lb,
118+
ptrdiff_t *full_true_ub ) {
119+
120+
ptrdiff_t extent, true_extent, full_true_extent;
121+
ptrdiff_t lb, true_lb;
122+
int rc;
123+
124+
/* note: full_true_lb and full_true_ub are undefined when count == 0!
125+
In this case, set them to 0 and 0. */
126+
*full_true_lb = 0;
127+
*full_true_ub = 0;
128+
if (count == 0) {
129+
return 0;
130+
}
131+
132+
rc = ompi_datatype_get_true_extent( dtype, &true_lb, &true_extent);
133+
if (rc) { return rc; }
134+
rc = ompi_datatype_get_extent( dtype, &lb, &extent);
135+
if (rc) { return rc; }
136+
137+
/* extent of data */
138+
full_true_extent = extent*MAX(0,count-1) + true_extent;
139+
/* for displ, only extent matters (not true_extent)*/
140+
ptrdiff_t displ_bytes = displ * extent;
141+
142+
/* now compute full true LB/UB including displ. */
143+
*full_true_lb = true_lb + displ_bytes;
144+
*full_true_ub = *full_true_lb + full_true_extent;
145+
146+
return 0;
147+
}
148+
86149
/* Serialize the datatype into the buffer and return buffer length.
87150
If buf is NULL, just return the length of the required buffer. */
88151
static size_t ddt_pack_datatype(opal_datatype_t* type, uint8_t* buf)
@@ -142,7 +205,13 @@ static size_t ddt_unpack_datatype(opal_datatype_t* type, uint8_t* buf)
142205
return length;
143206
}
144207

145-
/* basic implementation: send all buffers without packing keeping a limited number in flight. */
208+
/* Simple implementation: send all buffers without packing, but still keeping a
209+
limited number in flight.
210+
211+
Note: CMA on XPMEM-mapped buffers does not work. If the low-level network
212+
provider attempts to use CMA to implement send/recv, then errors will
213+
occur!
214+
*/
146215
static inline int alltoallv_sendrecv_w_direct_for_debugging(
147216
void **send_from_addrs,
148217
size_t *send_counts,
@@ -185,6 +254,7 @@ static inline int alltoallv_sendrecv_w_direct_for_debugging(
185254
} else {
186255
have_completion = 1;
187256
rc = ompi_request_wait_any( nreqs, requests, &jreq, MPI_STATUS_IGNORE );
257+
if (rc) break;
188258
}
189259
int ii_send_req = jreq >= jfirst_sendreq;
190260
if (have_completion) {
@@ -725,28 +795,49 @@ int mca_coll_han_alltoallv_using_smsc(
725795
opal_datatype_t *peer_recv_types = malloc(sizeof(*peer_recv_types) * low_size);
726796

727797
low_gather_in.serialization_buffer = serialization_buf;
728-
low_gather_in.sbuf = (void*)sbuf; // discard const
798+
low_gather_in.sbuf = (void*)sbuf; // cast to discard the const
729799
low_gather_in.rbuf = rbuf;
730800

731-
low_gather_in.sbuf_length = 0;
732-
low_gather_in.rbuf_length = 0;
733-
ptrdiff_t sextent;
734-
ptrdiff_t rextent;
735-
rc = ompi_datatype_type_extent( sdtype, &sextent);
736-
rc = ompi_datatype_type_extent( rdtype, &rextent);
801+
ptrdiff_t r_extent, r_lb;
802+
rc = ompi_datatype_get_extent( rdtype, &r_lb, &r_extent);
737803

738-
/* calculate the extent of our buffers so that peers can mmap the whole thing */
804+
ssize_t min_send_lb = SSIZE_MAX;
805+
ssize_t max_send_ub = -SSIZE_MAX-1;
806+
ssize_t min_recv_lb = SSIZE_MAX;
807+
ssize_t max_recv_ub = -SSIZE_MAX-1;
808+
/* calculate the maximal bounds of our buffers so that peers can mmap the whole thing. */
739809
for (int jrankw=0; jrankw<w_size; jrankw++) {
740-
size_t sz;
741-
sz = (ompi_disp_array_get(sdispls,jrankw) + ompi_count_array_get(scounts,jrankw))*sextent;
742-
if (sz > low_gather_in.sbuf_length) {
743-
low_gather_in.sbuf_length = sz;
810+
ptrdiff_t displ;
811+
ptrdiff_t count;
812+
ssize_t send_lb, send_ub, recv_lb, recv_ub;
813+
814+
count = ompi_count_array_get(scounts,jrankw);
815+
displ = ompi_disp_array_get(sdispls,jrankw);
816+
if (count > 0) {
817+
han_alltoallv_dtype_get_many_true_lb_ub( sdtype, count, displ, &send_lb, &send_ub);
818+
min_send_lb = MIN( min_send_lb, send_lb );
819+
max_send_ub = MAX( max_send_ub, send_ub );
744820
}
745-
sz = (ompi_disp_array_get(rdispls,jrankw) + ompi_count_array_get(rcounts,jrankw))*rextent;
746-
if (sz > low_gather_in.rbuf_length) {
747-
low_gather_in.rbuf_length = sz;
821+
822+
count = ompi_count_array_get(rcounts,jrankw);
823+
displ = ompi_disp_array_get(rdispls,jrankw);
824+
if (count > 0) {
825+
han_alltoallv_dtype_get_many_true_lb_ub( rdtype, count, displ, &recv_lb, &recv_ub);
826+
min_recv_lb = MIN( min_recv_lb, recv_lb );
827+
max_recv_ub = MAX( max_recv_ub, recv_ub );
748828
}
749829
}
830+
low_gather_in.sbuf_length = 0;
831+
if (max_send_ub > min_send_lb) {
832+
low_gather_in.sbuf_length = max_send_ub - min_send_lb;
833+
low_gather_in.sbuf_lb = min_send_lb;
834+
}
835+
836+
low_gather_in.rbuf_length = 0;
837+
if (max_recv_ub > min_recv_lb) {
838+
low_gather_in.rbuf_length = max_recv_ub - min_recv_lb;
839+
low_gather_in.rbuf_lb = min_recv_lb;
840+
}
750841

751842
/* pack the serialization buffer: first the array of counts */
752843
size_t buf_packed = 0;
@@ -782,18 +873,21 @@ int mca_coll_han_alltoallv_using_smsc(
782873
783874
*/
784875
for (int jrank=0; jrank<low_size; jrank++) {
876+
void *tmp_ptr;
877+
peers[jrank].map_ctx[0] = NULL;
878+
peers[jrank].map_ctx[1] = NULL;
879+
peers[jrank].map_ctx[2] = NULL;
880+
785881
if (jrank == low_rank) {
786882
/* special case for ourself */
787883
peers[jrank].counts = (struct peer_counts *)serialization_buf;
788884
peers[jrank].sbuf = sbuf;
789885
peers[jrank].rbuf = rbuf;
790886
peers[jrank].recvtype = &rdtype->super;
791887
peers[jrank].sendtype = &sdtype->super;
792-
peers[jrank].map_ctx[0] = NULL;
793-
peers[jrank].map_ctx[1] = NULL;
794-
peers[jrank].map_ctx[2] = NULL;
795888
continue;
796889
}
890+
797891
struct gathered_data *gathered = &low_gather_out[jrank];
798892
struct ompi_proc_t* ompi_proc = ompi_comm_peer_lookup(low_comm, jrank);
799893
mca_smsc_endpoint_t *smsc_ep;
@@ -812,18 +906,15 @@ int mca_coll_han_alltoallv_using_smsc(
812906
gathered->serialization_buffer,
813907
peer_serialization_buf_length,
814908
(void**) &peer_serialization_buf );
815-
peers[jrank].map_ctx[1] = mca_smsc->map_peer_region(
816-
smsc_ep,
817-
MCA_RCACHE_FLAGS_PERSIST,
818-
gathered->sbuf,
819-
gathered->sbuf_length,
820-
(void**)&peers[jrank].sbuf );
821-
peers[jrank].map_ctx[2] = mca_smsc->map_peer_region(
822-
smsc_ep,
823-
MCA_RCACHE_FLAGS_PERSIST,
824-
gathered->rbuf,
825-
gathered->rbuf_length,
826-
&peers[jrank].rbuf );
909+
if (gathered->sbuf_length > 0) {
910+
peers[jrank].map_ctx[1] = mca_smsc->map_peer_region(
911+
smsc_ep,
912+
MCA_RCACHE_FLAGS_PERSIST,
913+
(char*)gathered->sbuf + gathered->sbuf_lb,
914+
gathered->sbuf_length,
915+
&tmp_ptr );
916+
peers[jrank].sbuf = (char*)tmp_ptr - gathered->sbuf_lb;
917+
}
827918

828919
/* point the counts pointer into the mmapped serialization buffer */
829920
peers[jrank].counts = (struct peer_counts*)peer_serialization_buf;
@@ -867,10 +958,10 @@ int mca_coll_han_alltoallv_using_smsc(
867958

868959
send_from_addrs[jlow] = from_addr;
869960
send_counts[jlow] = peers[jlow].counts[jrank_sendto].scount;
870-
send_types[jlow] = peers[jlow].sendtype;
871-
// send_types[jlow] = &(sdtype->super);
961+
send_types[jlow] = peers[jlow].sendtype;
962+
872963

873-
recv_to_addrs[jlow] = (uint8_t*)rbuf + ompi_disp_array_get(rdispls,remote_wrank)*rextent;
964+
recv_to_addrs[jlow] = (uint8_t*)rbuf + ompi_disp_array_get(rdispls,remote_wrank)*r_extent;
874965
recv_counts[jlow] = ompi_count_array_get(rcounts,remote_wrank);
875966
recv_types[jlow] = &(rdtype->super);
876967
}

0 commit comments

Comments
 (0)