@@ -78,11 +78,74 @@ struct gathered_data {
    size_t rtype_serialized_length;
    size_t sbuf_length;
    size_t rbuf_length;
+   ssize_t sbuf_lb;
+   ssize_t rbuf_lb;
    void *sbuf;
    void *rbuf;
    void *serialization_buffer;
};

+/**
+   Given a count, displ, and type, compute the true LB and UB for all of the
+   data used by the count and displ arguments.
+   Note that the byte at UB is not accessed, so full_true_extent is UB - LB.
+
+   Consider the most difficult case: a resized type with non-zero LB,
+   extent != true_extent, and true_LB != LB.  In the figure below:
+      X represents the 0-point of the user's `buf`
+      x represents data accessed by the type
+      - represents data spanned by the type when count > 1
+      . represents data not accessed or spanned by the type.
+
+         + LB = -5
+         |  + true_LB = -2
+         |  | + buf (0)
+         |  | |
+      ...---xxXxxx----...
+         |  |<-->|              true_extent = 6
+         |<--------->|          extent      = 13
+
+   When there are 2 items, the full true extent is
+      ...---xxXxxx-------xxxxxx----...
+         |  |<--------------->|        true_extent = 19  ie: extent*(n-1) + true_extent
+         |<---------------------->|    extent      = 26  ie: extent*n
+*/
+static int han_alltoallv_dtype_get_many_true_lb_ub(
+    ompi_datatype_t *dtype,
+    ptrdiff_t count,
+    ptrdiff_t displ,
+    ptrdiff_t *full_true_lb,
+    ptrdiff_t *full_true_ub) {
+
+    ptrdiff_t extent, true_extent, full_true_extent;
+    ptrdiff_t lb, true_lb;
+    int rc;
+
+    /* Note: full_true_lb and full_true_ub are undefined when count == 0!
+       In that case, set them both to 0. */
+    *full_true_lb = 0;
+    *full_true_ub = 0;
+    if (count == 0) {
+        return 0;
+    }
+
+    rc = ompi_datatype_get_true_extent(dtype, &true_lb, &true_extent);
+    if (rc) { return rc; }
+    rc = ompi_datatype_get_extent(dtype, &lb, &extent);
+    if (rc) { return rc; }
+
+    /* extent of the data */
+    full_true_extent = extent * MAX(0, count - 1) + true_extent;
+    /* for displ, only extent matters (not true_extent) */
+    ptrdiff_t displ_bytes = displ * extent;
+
+    /* now compute the full true LB/UB including displ */
+    *full_true_lb = true_lb + displ_bytes;
+    *full_true_ub = *full_true_lb + full_true_extent;
+
+    return 0;
+}
+
/* Serialize the datatype into the buffer and return buffer length.
   If buf is NULL, just return the length of the required buffer. */
static size_t ddt_pack_datatype(opal_datatype_t* type, uint8_t* buf)
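
As a quick cross-check of the doc comment introduced above, the short standalone snippet below (not part of the patch) plugs the figure's example values into the same formula; MAX(0, count - 1) is written as a ternary so the snippet does not depend on the opal MAX macro.

#include <stddef.h>
#include <stdio.h>

/* Standalone check of the full-true-extent arithmetic using the figure's
   example type: extent = 13, true_extent = 6, true_LB = -2, count = 2. */
int main(void)
{
    ptrdiff_t extent = 13, true_extent = 6, true_lb = -2;
    ptrdiff_t count = 2, displ = 0;

    ptrdiff_t full_true_extent = extent * (count > 1 ? count - 1 : 0) + true_extent;
    ptrdiff_t full_true_lb = true_lb + displ * extent;
    ptrdiff_t full_true_ub = full_true_lb + full_true_extent;

    /* Expect full_true_extent = 19, lb = -2, ub = 17, matching the two-item figure. */
    printf("full_true_extent=%td lb=%td ub=%td\n",
           full_true_extent, full_true_lb, full_true_ub);
    return 0;
}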
@@ -142,7 +205,13 @@ static size_t ddt_unpack_datatype(opal_datatype_t* type, uint8_t* buf)
    return length;
}

-/* basic implementation: send all buffers without packing keeping a limited number in flight. */
+/* Simple implementation: send all buffers without packing, while still
+   keeping a limited number of requests in flight.
+
+   Note: CMA on XPMEM-mapped buffers does not work.  If the low-level network
+   provider attempts to use CMA to implement send/recv, then errors will
+   occur!
+*/
static inline int alltoallv_sendrecv_w_direct_for_debugging(
    void **send_from_addrs,
    size_t *send_counts,
@@ -185,6 +254,7 @@ static inline int alltoallv_sendrecv_w_direct_for_debugging(
        } else {
            have_completion = 1;
            rc = ompi_request_wait_any(nreqs, requests, &jreq, MPI_STATUS_IGNORE);
+           if (rc) break;
        }
        int ii_send_req = jreq >= jfirst_sendreq;
        if (have_completion) {
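
The "limited number in flight" pattern referenced in the comment above can be illustrated against the public MPI API, independent of the Open MPI internals in this file; every name below (MAX_INFLIGHT, send_all_limited) is illustrative only and not taken from the patch.

#include <mpi.h>

#define MAX_INFLIGHT 8

/* Issue nonblocking sends for nmsgs buffers, never keeping more than
   MAX_INFLIGHT requests outstanding at once. */
static int send_all_limited(void **bufs, int *counts, MPI_Datatype dtype,
                            int peer, int tag, MPI_Comm comm, int nmsgs)
{
    MPI_Request reqs[MAX_INFLIGHT];
    int jnext = 0;      /* next message to post */
    int ninflight = 0;  /* currently outstanding requests */
    int rc = MPI_SUCCESS;

    for (int j = 0; j < MAX_INFLIGHT; j++) reqs[j] = MPI_REQUEST_NULL;

    while (jnext < nmsgs || ninflight > 0) {
        if (jnext < nmsgs && ninflight < MAX_INFLIGHT) {
            /* room in the window: post another send into a free slot */
            for (int j = 0; j < MAX_INFLIGHT; j++) {
                if (reqs[j] == MPI_REQUEST_NULL) {
                    rc = MPI_Isend(bufs[jnext], counts[jnext], dtype,
                                   peer, tag, comm, &reqs[j]);
                    if (rc != MPI_SUCCESS) return rc;
                    jnext++;
                    ninflight++;
                    break;
                }
            }
        } else {
            /* window full (or nothing left to post): wait for one completion;
               MPI_Waitany resets the completed slot to MPI_REQUEST_NULL */
            int jdone;
            rc = MPI_Waitany(MAX_INFLIGHT, reqs, &jdone, MPI_STATUS_IGNORE);
            if (rc != MPI_SUCCESS) return rc;
            ninflight--;
        }
    }
    return MPI_SUCCESS;
}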
@@ -725,28 +795,49 @@ int mca_coll_han_alltoallv_using_smsc(
    opal_datatype_t *peer_recv_types = malloc(sizeof(*peer_recv_types) * low_size);

    low_gather_in.serialization_buffer = serialization_buf;
-   low_gather_in.sbuf = (void *)sbuf;   // discard const
+   low_gather_in.sbuf = (void *)sbuf;   // cast to discard the const
    low_gather_in.rbuf = rbuf;

-   low_gather_in.sbuf_length = 0;
-   low_gather_in.rbuf_length = 0;
-   ptrdiff_t sextent;
-   ptrdiff_t rextent;
-   rc = ompi_datatype_type_extent(sdtype, &sextent);
-   rc = ompi_datatype_type_extent(rdtype, &rextent);
+   ptrdiff_t r_extent, r_lb;
+   rc = ompi_datatype_get_extent(rdtype, &r_lb, &r_extent);

-   /* calculate the extent of our buffers so that peers can mmap the whole thing */
+   ssize_t min_send_lb = SSIZE_MAX;
+   ssize_t max_send_ub = -SSIZE_MAX - 1;
+   ssize_t min_recv_lb = SSIZE_MAX;
+   ssize_t max_recv_ub = -SSIZE_MAX - 1;
+   /* calculate the maximal bounds of our buffers so that peers can mmap the whole thing. */
    for (int jrankw = 0; jrankw < w_size; jrankw++) {
-       size_t sz;
-       sz = (ompi_disp_array_get(sdispls, jrankw) + ompi_count_array_get(scounts, jrankw)) * sextent;
-       if (sz > low_gather_in.sbuf_length) {
-           low_gather_in.sbuf_length = sz;
+       ptrdiff_t displ;
+       ptrdiff_t count;
+       ssize_t send_lb, send_ub, recv_lb, recv_ub;
+
+       count = ompi_count_array_get(scounts, jrankw);
+       displ = ompi_disp_array_get(sdispls, jrankw);
+       if (count > 0) {
+           han_alltoallv_dtype_get_many_true_lb_ub(sdtype, count, displ, &send_lb, &send_ub);
+           min_send_lb = MIN(min_send_lb, send_lb);
+           max_send_ub = MAX(max_send_ub, send_ub);
        }
-       sz = (ompi_disp_array_get(rdispls, jrankw) + ompi_count_array_get(rcounts, jrankw)) * rextent;
-       if (sz > low_gather_in.rbuf_length) {
-           low_gather_in.rbuf_length = sz;
+
+       count = ompi_count_array_get(rcounts, jrankw);
+       displ = ompi_disp_array_get(rdispls, jrankw);
+       if (count > 0) {
+           han_alltoallv_dtype_get_many_true_lb_ub(rdtype, count, displ, &recv_lb, &recv_ub);
+           min_recv_lb = MIN(min_recv_lb, recv_lb);
+           max_recv_ub = MAX(max_recv_ub, recv_ub);
        }
    }
+   low_gather_in.sbuf_length = 0;
+   if (max_send_ub > min_send_lb) {
+       low_gather_in.sbuf_length = max_send_ub - min_send_lb;
+       low_gather_in.sbuf_lb = min_send_lb;
+   }
+
+   low_gather_in.rbuf_length = 0;
+   if (max_recv_ub > min_recv_lb) {
+       low_gather_in.rbuf_length = max_recv_ub - min_recv_lb;
+       low_gather_in.rbuf_lb = min_recv_lb;
+   }

    /* pack the serialization buffer: first the array of counts */
    size_t buf_packed = 0;
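
For a concrete sense of what the new bounds computation buys over the removed (displ + count) * extent sizing, the snippet below (illustrative only, reusing the example type from the doc comment) compares the two ranges; with a negative true_LB the old range starting at buf misses bytes that a peer may actually read.

#include <stddef.h>
#include <stdio.h>

int main(void)
{
    ptrdiff_t extent = 13, true_extent = 6, true_lb = -2;
    ptrdiff_t count = 2, displ = 0;

    /* old approach: (displ + count) * extent bytes starting at buf */
    ptrdiff_t old_length = (displ + count) * extent;           /* 26 bytes: [0, 26) */

    /* new approach: the true byte range actually touched, relative to buf */
    ptrdiff_t lb = true_lb + displ * extent;                   /* -2 */
    ptrdiff_t ub = lb + extent * (count - 1) + true_extent;    /* 17 */

    printf("old: [0, %td)   new: [%td, %td)\n", old_length, lb, ub);
    /* The two bytes at [-2, 0) are accessed but fall outside the old range,
       so mapping only [0, 26) would not cover everything a peer may read. */
    return 0;
}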
@@ -782,18 +873,21 @@ int mca_coll_han_alltoallv_using_smsc(

    */
    for (int jrank = 0; jrank < low_size; jrank++) {
+       void *tmp_ptr;
+       peers[jrank].map_ctx[0] = NULL;
+       peers[jrank].map_ctx[1] = NULL;
+       peers[jrank].map_ctx[2] = NULL;
+
        if (jrank == low_rank) {
            /* special case for ourself */
            peers[jrank].counts = (struct peer_counts *)serialization_buf;
            peers[jrank].sbuf = sbuf;
            peers[jrank].rbuf = rbuf;
            peers[jrank].recvtype = &rdtype->super;
            peers[jrank].sendtype = &sdtype->super;
-           peers[jrank].map_ctx[0] = NULL;
-           peers[jrank].map_ctx[1] = NULL;
-           peers[jrank].map_ctx[2] = NULL;
            continue;
        }
+
        struct gathered_data *gathered = &low_gather_out[jrank];
        struct ompi_proc_t *ompi_proc = ompi_comm_peer_lookup(low_comm, jrank);
        mca_smsc_endpoint_t *smsc_ep;
@@ -812,18 +906,15 @@ int mca_coll_han_alltoallv_using_smsc(
            gathered->serialization_buffer,
            peer_serialization_buf_length,
            (void **)&peer_serialization_buf);
-       peers[jrank].map_ctx[1] = mca_smsc->map_peer_region(
-           smsc_ep,
-           MCA_RCACHE_FLAGS_PERSIST,
-           gathered->sbuf,
-           gathered->sbuf_length,
-           (void **)&peers[jrank].sbuf);
-       peers[jrank].map_ctx[2] = mca_smsc->map_peer_region(
-           smsc_ep,
-           MCA_RCACHE_FLAGS_PERSIST,
-           gathered->rbuf,
-           gathered->rbuf_length,
-           &peers[jrank].rbuf);
+       if (gathered->sbuf_length > 0) {
+           peers[jrank].map_ctx[1] = mca_smsc->map_peer_region(
+               smsc_ep,
+               MCA_RCACHE_FLAGS_PERSIST,
+               (char *)gathered->sbuf + gathered->sbuf_lb,
+               gathered->sbuf_length,
+               &tmp_ptr);
+           peers[jrank].sbuf = (char *)tmp_ptr - gathered->sbuf_lb;
+       }

        /* point the counts pointer into the mmapped serialization buffer */
        peers[jrank].counts = (struct peer_counts *)peer_serialization_buf;
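
The pointer adjustment in the hunk above can be summarized by the sketch below, where map_remote_range() is a hypothetical stand-in for mca_smsc->map_peer_region(): only the truly-used byte range is mapped, and the lower bound is subtracted back out so that existing displacement arithmetic relative to the peer's original buffer pointer keeps working unchanged.

#include <stddef.h>

void *map_remote_range(void *remote_addr, size_t length);   /* hypothetical */

static void *map_peer_buf(void *peer_buf, ptrdiff_t lb, size_t length)
{
    /* local_base points at the first truly-used byte, i.e. peer_buf + lb */
    char *local_base = map_remote_range((char *)peer_buf + lb, length);

    /* Shift back by lb: the returned pointer plays the role of peer_buf, so
       "mapped + displ*extent + true_lb" lands inside the mapped range.  The
       returned pointer itself is never dereferenced directly; only offsets
       within [lb, lb + length) are used. */
    return local_base - lb;
}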
@@ -867,10 +958,10 @@ int mca_coll_han_alltoallv_using_smsc(

        send_from_addrs[jlow] = from_addr;
        send_counts[jlow] = peers[jlow].counts[jrank_sendto].scount;
-       send_types[jlow] = peers[jlow].sendtype;
-       // send_types[jlow] = &(sdtype->super);
+       send_types[jlow] = peers[jlow].sendtype;
+

-       recv_to_addrs[jlow] = (uint8_t *)rbuf + ompi_disp_array_get(rdispls, remote_wrank) * rextent;
+       recv_to_addrs[jlow] = (uint8_t *)rbuf + ompi_disp_array_get(rdispls, remote_wrank) * r_extent;
        recv_counts[jlow] = ompi_count_array_get(rcounts, remote_wrank);
        recv_types[jlow] = &(rdtype->super);
    }