25
25
#include "ompi/datatype/ompi_datatype.h"
26
26
#include "opal/datatype/opal_convertor.h"
27
27
#include "opal/datatype/opal_datatype_internal.h"
28
+ #if OPAL_CUDA_SUPPORT
29
+ #include "opal/mca/common/cuda/common_cuda.h"
30
+ #include "opal/datatype/opal_convertor.h"
31
+ #include "opal/datatype/opal_datatype_cuda.h"
32
+ #endif
28
33
29
34
#ifndef MTL_BASE_DATATYPE_H_INCLUDED
30
35
#define MTL_BASE_DATATYPE_H_INCLUDED
31
36
37
+ #if OPAL_CUDA_SUPPORT
38
+ static int
39
+ ompi_mtl_cuda_datatype_pack (struct opal_convertor_t * convertor ,
40
+ void * * buffer ,
41
+ size_t * buffer_len ,
42
+ bool * freeAfter )
43
+ {
44
+
45
+ struct iovec iov ;
46
+ uint32_t iov_count = 1 ;
47
+ int is_cuda = convertor -> flags & CONVERTOR_CUDA ;
48
+
49
+ #if !(OPAL_ENABLE_HETEROGENEOUS_SUPPORT )
50
+ if (convertor -> pDesc &&
51
+ !(convertor -> flags & CONVERTOR_COMPLETED ) &&
52
+ opal_datatype_is_contiguous_memory_layout (convertor -> pDesc ,
53
+ convertor -> count )) {
54
+ * freeAfter = false;
55
+ * buffer = convertor -> pBaseBuf ;
56
+ * buffer_len = convertor -> local_size ;
57
+ return OPAL_SUCCESS ;
58
+ }
59
+ #endif
60
+
61
+ opal_convertor_get_packed_size (convertor , buffer_len );
62
+ * freeAfter = false;
63
+ if ( 0 == * buffer_len ) {
64
+ * buffer = NULL ;
65
+ return OMPI_SUCCESS ;
66
+ }
67
+ iov .iov_len = * buffer_len ;
68
+ iov .iov_base = NULL ;
69
+ /* opal_convertor_need_buffers always returns true
70
+ * if CONVERTOR_CUDA is set, so unset temporarily
71
+ */
72
+ convertor -> flags &= ~CONVERTOR_CUDA ;
73
+
74
+ if (opal_convertor_need_buffers (convertor )) {
75
+ if (is_cuda ) {
76
+ convertor -> flags |= CONVERTOR_CUDA ;
77
+ }
78
+ iov .iov_base = opal_cuda_malloc (* buffer_len , convertor );
79
+ if (NULL == iov .iov_base ) return OMPI_ERR_OUT_OF_RESOURCE ;
80
+ * freeAfter = true;
81
+ } else if (is_cuda ) {
82
+ convertor -> flags |= CONVERTOR_CUDA ;
83
+ }
84
+
85
+ opal_convertor_pack ( convertor , & iov , & iov_count , buffer_len );
86
+
87
+ * buffer = iov .iov_base ;
88
+
89
+ return OMPI_SUCCESS ;
90
+ }
91
+ #endif
92
+
32
93
__opal_attribute_always_inline__ static inline int
33
94
ompi_mtl_datatype_pack (struct opal_convertor_t * convertor ,
34
95
void * * buffer ,
35
96
size_t * buffer_len ,
36
97
bool * freeAfter )
37
98
{
99
+ #if OPAL_CUDA_SUPPORT
100
+ return ompi_mtl_cuda_datatype_pack (convertor ,
101
+ buffer ,
102
+ buffer_len ,
103
+ freeAfter );
104
+ #endif
38
105
struct iovec iov ;
39
106
uint32_t iov_count = 1 ;
40
107
@@ -71,13 +138,56 @@ ompi_mtl_datatype_pack(struct opal_convertor_t *convertor,
71
138
return OMPI_SUCCESS ;
72
139
}
73
140
141
+ #if OPAL_CUDA_SUPPORT
142
+ static int
143
+ ompi_mtl_cuda_datatype_recv_buf (struct opal_convertor_t * convertor ,
144
+ void * * buffer ,
145
+ size_t * buffer_len ,
146
+ bool * free_on_error )
147
+ {
148
+ int is_cuda = convertor -> flags & CONVERTOR_CUDA ;
149
+ opal_convertor_get_packed_size (convertor , buffer_len );
150
+ * free_on_error = false;
151
+ if ( 0 == * buffer_len ) {
152
+ * buffer = NULL ;
153
+ * buffer_len = 0 ;
154
+ return OMPI_SUCCESS ;
155
+ }
156
+ /* opal_convertor_need_buffers always returns true
157
+ * if CONVERTOR_CUDA is set, so unset temporarily
158
+ */
159
+ convertor -> flags &= ~CONVERTOR_CUDA ;
160
+ if (opal_convertor_need_buffers (convertor )) {
161
+ if (is_cuda ) {
162
+ convertor -> flags |= CONVERTOR_CUDA ;
163
+ }
164
+ * buffer = opal_cuda_malloc (* buffer_len , convertor );
165
+ * free_on_error = true;
166
+ } else {
167
+ if (is_cuda ) {
168
+ convertor -> flags |= CONVERTOR_CUDA ;
169
+ }
170
+ * buffer = convertor -> pBaseBuf +
171
+ convertor -> use_desc -> desc [convertor -> use_desc -> used ].end_loop .first_elem_disp ;
172
+ }
173
+ return OMPI_SUCCESS ;
174
+
175
+ }
176
+ #endif
74
177
75
178
__opal_attribute_always_inline__ static inline int
76
179
ompi_mtl_datatype_recv_buf (struct opal_convertor_t * convertor ,
77
180
void * * buffer ,
78
181
size_t * buffer_len ,
79
182
bool * free_on_error )
80
183
{
184
+ #if OPAL_CUDA_SUPPORT
185
+ return ompi_mtl_cuda_datatype_recv_buf (convertor ,
186
+ buffer ,
187
+ buffer_len ,
188
+ free_on_error );
189
+ #endif
190
+
81
191
opal_convertor_get_packed_size (convertor , buffer_len );
82
192
* free_on_error = false;
83
193
if ( 0 == * buffer_len ) {
@@ -95,12 +205,48 @@ ompi_mtl_datatype_recv_buf(struct opal_convertor_t *convertor,
95
205
return OMPI_SUCCESS ;
96
206
}
97
207
208
+ #if OPAL_CUDA_SUPPORT
209
+ static int
210
+ ompi_mtl_cuda_datatype_unpack (struct opal_convertor_t * convertor ,
211
+ void * buffer ,
212
+ size_t buffer_len ) {
213
+ struct iovec iov ;
214
+ uint32_t iov_count = 1 ;
215
+ int is_cuda = convertor -> flags & CONVERTOR_CUDA ;
216
+ /* opal_convertor_need_buffers always returns true
217
+ * if CONVERTOR_CUDA is set, so unset temporarily
218
+ */
219
+ convertor -> flags &= ~CONVERTOR_CUDA ;
220
+
221
+ if (buffer_len > 0 && opal_convertor_need_buffers (convertor )) {
222
+ iov .iov_len = buffer_len ;
223
+ iov .iov_base = buffer ;
224
+
225
+ if (is_cuda ) {
226
+ convertor -> flags |= CONVERTOR_CUDA ;
227
+ }
228
+ opal_convertor_unpack (convertor , & iov , & iov_count , & buffer_len );
229
+
230
+ opal_cuda_free (buffer , convertor );
231
+ } else if (is_cuda ) {
232
+ convertor -> flags |= CONVERTOR_CUDA ;
233
+ }
234
+
235
+ return OMPI_SUCCESS ;
236
+
237
+ }
238
+ #endif
98
239
99
240
__opal_attribute_always_inline__ static inline int
100
241
ompi_mtl_datatype_unpack (struct opal_convertor_t * convertor ,
101
242
void * buffer ,
102
243
size_t buffer_len )
103
244
{
245
+ #if OPAL_CUDA_SUPPORT
246
+ return ompi_mtl_cuda_datatype_unpack (convertor ,
247
+ buffer ,
248
+ buffer_len );
249
+ #endif
104
250
struct iovec iov ;
105
251
uint32_t iov_count = 1 ;
106
252
0 commit comments