Skip to content

Commit bfe486f

Browse files
authored
Passing descriptors by reference to CUDA runtime calls (#114288)
Passing a descriptor as a `const Descriptor &` or a `const Descriptor *` generates a FIR signature where the box is passed by value. This is an issue, as it requires a load of the box to be passed. But since, ultimately, all boxes are passed by reference a temporary is generated in LLVM and the reference to the temporary is passed. The boxes addresses are registered with the CUDA runtime but the temporaries are not, thus preventing the runtime to properly map a host side address to its device side counterpart. To address this issue, this PR changes the signatures to the transfer functions to pass a descriptor as a `Descriptor *`, which will in turn generate a FIR signature with that takes a box reference as an argument.
1 parent 84a78ab commit bfe486f

File tree

4 files changed

+23
-34
lines changed

4 files changed

+23
-34
lines changed

flang/include/flang/Runtime/CUDA/memory.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,18 @@ void RTDECL(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
3636
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
3737

3838
/// Data transfer from a pointer to a descriptor.
39-
void RTDECL(CUFDataTransferDescPtr)(const Descriptor &dst, void *src,
39+
void RTDECL(CUFDataTransferDescPtr)(Descriptor *dst, void *src,
4040
std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
4141
int sourceLine = 0);
4242

4343
/// Data transfer from a descriptor to a pointer.
44-
void RTDECL(CUFDataTransferPtrDesc)(void *dst, const Descriptor &src,
44+
void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
4545
std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
4646
int sourceLine = 0);
4747

4848
/// Data transfer from a descriptor to a descriptor.
49-
void RTDECL(CUFDataTransferDescDesc)(const Descriptor &dst,
50-
const Descriptor &src, unsigned mode, const char *sourceFile = nullptr,
51-
int sourceLine = 0);
49+
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
50+
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
5251

5352
} // extern "C"
5453
} // namespace Fortran::runtime::cuda

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,8 @@ struct CUFDataTransferOpConversion
529529
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
530530
mlir::Value sourceLine =
531531
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
532-
mlir::Value dst = builder.loadIfRef(loc, op.getDst());
533-
mlir::Value src = builder.loadIfRef(loc, op.getSrc());
532+
mlir::Value dst = op.getDst();
533+
mlir::Value src = op.getSrc();
534534
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
535535
builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
536536
builder.create<fir::CallOp>(loc, func, args);
@@ -603,11 +603,8 @@ struct CUFDataTransferOpConversion
603603
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
604604
mlir::Value sourceLine =
605605
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
606-
mlir::Value dst =
607-
dstIsDesc ? builder.loadIfRef(loc, op.getDst()) : op.getDst();
608-
mlir::Value src = mlir::isa<fir::BaseBoxType>(srcTy)
609-
? builder.loadIfRef(loc, op.getSrc())
610-
: op.getSrc();
606+
mlir::Value dst = op.getDst();
607+
mlir::Value src = op.getSrc();
611608
llvm::SmallVector<mlir::Value> args{
612609
fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
613610
modeValue, sourceFile, sourceLine)};

flang/runtime/CUDA/memory.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,23 +73,22 @@ void RTDEF(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
7373
CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, bytes, kind));
7474
}
7575

76-
void RTDEF(CUFDataTransferDescPtr)(const Descriptor &desc, void *addr,
76+
void RTDEF(CUFDataTransferDescPtr)(Descriptor *desc, void *addr,
7777
std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
7878
Terminator terminator{sourceFile, sourceLine};
7979
terminator.Crash(
8080
"not yet implemented: CUDA data transfer from a pointer to a descriptor");
8181
}
8282

83-
void RTDEF(CUFDataTransferPtrDesc)(void *addr, const Descriptor &desc,
83+
void RTDEF(CUFDataTransferPtrDesc)(void *addr, Descriptor *desc,
8484
std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
8585
Terminator terminator{sourceFile, sourceLine};
8686
terminator.Crash(
8787
"not yet implemented: CUDA data transfer from a descriptor to a pointer");
8888
}
8989

90-
void RTDECL(CUFDataTransferDescDesc)(const Descriptor &dstDesc,
91-
const Descriptor &srcDesc, unsigned mode, const char *sourceFile,
92-
int sourceLine) {
90+
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
91+
unsigned mode, const char *sourceFile, int sourceLine) {
9392
Terminator terminator{sourceFile, sourceLine};
9493
terminator.Crash(
9594
"not yet implemented: CUDA data transfer between two descriptors");

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@ func.func @_QPsub1() {
1515
// CHECK-LABEL: func.func @_QPsub1()
1616
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
1717
// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
18-
// CHECK: %[[AHOST_LOAD:.*]] = fir.load %[[AHOST]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
19-
// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
20-
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[AHOST_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
21-
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
22-
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.box<none>, i32, !fir.ref<i8>, i32) -> none
18+
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
19+
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
20+
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
2321

2422
func.func @_QPsub2() {
2523
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub2Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -76,19 +74,17 @@ func.func @_QPsub4() {
7674
// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
7775
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
7876
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
79-
// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
80-
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
77+
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
8178
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
8279
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
83-
// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
80+
// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
8481
// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
8582
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
8683
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
87-
// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
8884
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
89-
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
85+
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
9086
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
91-
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.box<none>, i64, i32, !fir.ref<i8>, i32) -> none
87+
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
9288

9389
func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
9490
%0 = fir.dummy_scope : !fir.dscope
@@ -122,19 +118,17 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
122118
// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
123119
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
124120
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
125-
// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
126-
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>) -> !fir.box<none>
121+
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
127122
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
128123
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
129-
// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
124+
// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
130125
// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
131126
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
132127
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
133-
// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
134128
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
135-
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>) -> !fir.box<none>
129+
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
136130
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
137-
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.box<none>, i64, i32, !fir.ref<i8>, i32) -> none
131+
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
138132

139133
func.func @_QPsub6() {
140134
%0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} -> !fir.ref<i32>

0 commit comments

Comments
 (0)