Skip to content

Commit 4da9d91

Browse files
authored
Merge pull request #8536 from wckzhang/gdr
Add CUDA support for the OFI MTL
2 parents 57026cf + deb37ac commit 4da9d91

30 files changed

+745
-316
lines changed

config/opal_check_cuda.m4

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ AC_MSG_CHECKING([if have cuda support])
117117
if test "$opal_check_cuda_happy" = "yes"; then
118118
AC_MSG_RESULT([yes (-I$opal_cuda_incdir)])
119119
CUDA_SUPPORT=1
120-
opal_datatype_cuda_CPPFLAGS="-I$opal_cuda_incdir"
121-
AC_SUBST([opal_datatype_cuda_CPPFLAGS])
120+
common_cuda_CPPFLAGS="-I$opal_cuda_incdir"
121+
AC_SUBST([common_cuda_CPPFLAGS])
122122
else
123123
AC_MSG_RESULT([no])
124124
CUDA_SUPPORT=0

ompi/mca/coll/cuda/coll_cuda_allreduce.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
#include "ompi/op/op.h"
1919
#include "opal/datatype/opal_convertor.h"
20-
#include "opal/datatype/opal_datatype_cuda.h"
20+
#include "opal/mca/common/cuda/common_cuda.h"
2121

2222
/*
2323
* allreduce_intra

ompi/mca/coll/cuda/coll_cuda_exscan.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
#include "ompi/op/op.h"
1919
#include "opal/datatype/opal_convertor.h"
20-
#include "opal/datatype/opal_datatype_cuda.h"
20+
#include "opal/mca/common/cuda/common_cuda.h"
2121

2222
int mca_coll_cuda_exscan(const void *sbuf, void *rbuf, int count,
2323
struct ompi_datatype_t *dtype,

ompi/mca/coll/cuda/coll_cuda_reduce.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
#include "ompi/op/op.h"
1919
#include "opal/datatype/opal_convertor.h"
20-
#include "opal/datatype/opal_datatype_cuda.h"
20+
#include "opal/mca/common/cuda/common_cuda.h"
2121

2222
/*
2323
* reduce_log_inter

ompi/mca/coll/cuda/coll_cuda_reduce_scatter_block.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
#include "ompi/op/op.h"
1919
#include "opal/datatype/opal_convertor.h"
20-
#include "opal/datatype/opal_datatype_cuda.h"
20+
#include "opal/mca/common/cuda/common_cuda.h"
2121

2222
/*
2323
* reduce_scatter_block

ompi/mca/coll/cuda/coll_cuda_scan.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
#include "ompi/op/op.h"
1919
#include "opal/datatype/opal_convertor.h"
20-
#include "opal/datatype/opal_datatype_cuda.h"
20+
#include "opal/mca/common/cuda/common_cuda.h"
2121

2222
/*
2323
* scan

ompi/mca/coll/libnbc/nbc_internal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
#include "coll_libnbc.h"
3232
#if OPAL_CUDA_SUPPORT
3333
#include "opal/datatype/opal_convertor.h"
34-
#include "opal/datatype/opal_datatype_cuda.h"
34+
#include "opal/mca/common/cuda/common_cuda.h"
3535
#endif /* OPAL_CUDA_SUPPORT */
3636
#include "ompi/include/ompi/constants.h"
3737
#include "ompi/request/request.h"

ompi/mca/common/ompio/common_ompio_buffer.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include "ompi_config.h"
2121

2222
#include "opal/datatype/opal_convertor.h"
23-
#include "opal/datatype/opal_datatype_cuda.h"
2423
#include "opal/mca/common/cuda/common_cuda.h"
2524
#include "opal/util/sys_limits.h"
2625

ompi/mca/mtl/base/mtl_base_datatype.h

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,82 @@
2525
#include "ompi/datatype/ompi_datatype.h"
2626
#include "opal/datatype/opal_convertor.h"
2727
#include "opal/datatype/opal_datatype_internal.h"
28+
#if OPAL_CUDA_SUPPORT
29+
#include "opal/mca/common/cuda/common_cuda.h"
30+
#include "opal/datatype/opal_convertor.h"
31+
#endif
2832

2933
#ifndef MTL_BASE_DATATYPE_H_INCLUDED
3034
#define MTL_BASE_DATATYPE_H_INCLUDED
3135

36+
#if OPAL_CUDA_SUPPORT
37+
static int
38+
ompi_mtl_cuda_datatype_pack(struct opal_convertor_t *convertor,
39+
void **buffer,
40+
size_t *buffer_len,
41+
bool *freeAfter)
42+
{
43+
44+
struct iovec iov;
45+
uint32_t iov_count = 1;
46+
int is_cuda = convertor->flags & CONVERTOR_CUDA;
47+
48+
#if !(OPAL_ENABLE_HETEROGENEOUS_SUPPORT)
49+
if (convertor->pDesc &&
50+
!(convertor->flags & CONVERTOR_COMPLETED) &&
51+
opal_datatype_is_contiguous_memory_layout(convertor->pDesc,
52+
convertor->count)) {
53+
*freeAfter = false;
54+
*buffer = convertor->pBaseBuf;
55+
*buffer_len = convertor->local_size;
56+
return OPAL_SUCCESS;
57+
}
58+
#endif
59+
60+
opal_convertor_get_packed_size(convertor, buffer_len);
61+
*freeAfter = false;
62+
if( 0 == *buffer_len ) {
63+
*buffer = NULL;
64+
return OMPI_SUCCESS;
65+
}
66+
iov.iov_len = *buffer_len;
67+
iov.iov_base = NULL;
68+
/* opal_convertor_need_buffers always returns true
69+
* if CONVERTOR_CUDA is set, so unset temporarily
70+
*/
71+
convertor->flags &= ~CONVERTOR_CUDA;
72+
73+
if (opal_convertor_need_buffers(convertor)) {
74+
if (is_cuda) {
75+
convertor->flags |= CONVERTOR_CUDA;
76+
}
77+
iov.iov_base = opal_cuda_malloc(*buffer_len, convertor);
78+
if (NULL == iov.iov_base) return OMPI_ERR_OUT_OF_RESOURCE;
79+
*freeAfter = true;
80+
} else if (is_cuda) {
81+
convertor->flags |= CONVERTOR_CUDA;
82+
}
83+
84+
opal_convertor_pack( convertor, &iov, &iov_count, buffer_len );
85+
86+
*buffer = iov.iov_base;
87+
88+
return OMPI_SUCCESS;
89+
}
90+
#endif
91+
3292
__opal_attribute_always_inline__ static inline int
3393
ompi_mtl_datatype_pack(struct opal_convertor_t *convertor,
3494
void **buffer,
3595
size_t *buffer_len,
3696
bool *freeAfter)
3797
{
98+
#if OPAL_CUDA_SUPPORT
99+
return ompi_mtl_cuda_datatype_pack(convertor,
100+
buffer,
101+
buffer_len,
102+
freeAfter);
103+
#endif
38104
struct iovec iov;
39105
uint32_t iov_count = 1;
40106

@@ -71,13 +137,56 @@ ompi_mtl_datatype_pack(struct opal_convertor_t *convertor,
71137
return OMPI_SUCCESS;
72138
}
73139

140+
#if OPAL_CUDA_SUPPORT
141+
static int
142+
ompi_mtl_cuda_datatype_recv_buf(struct opal_convertor_t *convertor,
143+
void ** buffer,
144+
size_t *buffer_len,
145+
bool *free_on_error)
146+
{
147+
int is_cuda = convertor->flags & CONVERTOR_CUDA;
148+
opal_convertor_get_packed_size(convertor, buffer_len);
149+
*free_on_error = false;
150+
if( 0 == *buffer_len ) {
151+
*buffer = NULL;
152+
*buffer_len = 0;
153+
return OMPI_SUCCESS;
154+
}
155+
/* opal_convertor_need_buffers always returns true
156+
* if CONVERTOR_CUDA is set, so unset temporarily
157+
*/
158+
convertor->flags &= ~CONVERTOR_CUDA;
159+
if (opal_convertor_need_buffers(convertor)) {
160+
if (is_cuda) {
161+
convertor->flags |= CONVERTOR_CUDA;
162+
}
163+
*buffer = opal_cuda_malloc(*buffer_len, convertor);
164+
*free_on_error = true;
165+
} else {
166+
if (is_cuda) {
167+
convertor->flags |= CONVERTOR_CUDA;
168+
}
169+
*buffer = convertor->pBaseBuf +
170+
convertor->use_desc->desc[convertor->use_desc->used].end_loop.first_elem_disp;
171+
}
172+
return OMPI_SUCCESS;
173+
174+
}
175+
#endif
74176

75177
__opal_attribute_always_inline__ static inline int
76178
ompi_mtl_datatype_recv_buf(struct opal_convertor_t *convertor,
77179
void ** buffer,
78180
size_t *buffer_len,
79181
bool *free_on_error)
80182
{
183+
#if OPAL_CUDA_SUPPORT
184+
return ompi_mtl_cuda_datatype_recv_buf(convertor,
185+
buffer,
186+
buffer_len,
187+
free_on_error);
188+
#endif
189+
81190
opal_convertor_get_packed_size(convertor, buffer_len);
82191
*free_on_error = false;
83192
if( 0 == *buffer_len ) {
@@ -95,12 +204,48 @@ ompi_mtl_datatype_recv_buf(struct opal_convertor_t *convertor,
95204
return OMPI_SUCCESS;
96205
}
97206

207+
#if OPAL_CUDA_SUPPORT
208+
static int
209+
ompi_mtl_cuda_datatype_unpack(struct opal_convertor_t *convertor,
210+
void *buffer,
211+
size_t buffer_len) {
212+
struct iovec iov;
213+
uint32_t iov_count = 1;
214+
int is_cuda = convertor->flags & CONVERTOR_CUDA;
215+
/* opal_convertor_need_buffers always returns true
216+
* if CONVERTOR_CUDA is set, so unset temporarily
217+
*/
218+
convertor->flags &= ~CONVERTOR_CUDA;
219+
220+
if (buffer_len > 0 && opal_convertor_need_buffers(convertor)) {
221+
iov.iov_len = buffer_len;
222+
iov.iov_base = buffer;
223+
224+
if (is_cuda) {
225+
convertor->flags |= CONVERTOR_CUDA;
226+
}
227+
opal_convertor_unpack(convertor, &iov, &iov_count, &buffer_len );
228+
229+
opal_cuda_free(buffer, convertor);
230+
} else if (is_cuda) {
231+
convertor->flags |= CONVERTOR_CUDA;
232+
}
233+
234+
return OMPI_SUCCESS;
235+
236+
}
237+
#endif
98238

99239
__opal_attribute_always_inline__ static inline int
100240
ompi_mtl_datatype_unpack(struct opal_convertor_t *convertor,
101241
void *buffer,
102242
size_t buffer_len)
103243
{
244+
#if OPAL_CUDA_SUPPORT
245+
return ompi_mtl_cuda_datatype_unpack(convertor,
246+
buffer,
247+
buffer_len);
248+
#endif
104249
struct iovec iov;
105250
uint32_t iov_count = 1;
106251

ompi/mca/mtl/ofi/configure.m4

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@ AC_DEFUN([MCA_ompi_mtl_ofi_CONFIG],[
2828
# Check for OFI
2929
OPAL_CHECK_OFI
3030

31+
# Check for CUDA
32+
OPAL_CHECK_CUDA
33+
34+
# Check for cuda support. If so, we require a minimum libfabric version
35+
# of 1.9. FI_HMEM capabilities are only available starting from v1.9
36+
opal_ofi_happy="yes"
37+
AS_IF([test "$opal_check_cuda_happy" = "yes"],
38+
[OPAL_CHECK_OFI_VERSION_GE([1,9],
39+
[],
40+
[opal_ofi_happy=no])])
41+
3142
# The OFI MTL requires at least OFI libfabric v1.5.
3243
AS_IF([test "$opal_ofi_happy" = "yes"],
3344
[OPAL_CHECK_OFI_VERSION_GE([1,5],

ompi/mca/mtl/ofi/help-mtl-ofi.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,12 @@ recoverable and your application is likely to abort.
7777
Error: %s (%d)
7878
[message too big]
7979
Message size %llu bigger than supported by selected transport. Max = %llu
80+
81+
[Buffer Memory Registration Failed]
82+
Open MPI failed to register your buffer.
83+
This error is fatal, your job will abort
84+
85+
Buffer Type: %s
86+
Buffer Address: %p
87+
Buffer Length: %d
88+
Error: %s (%zd)

0 commit comments

Comments
 (0)