Skip to content

Commit 8c138be

Browse files
authored
[flang][cuda] Handle pointer allocation with source (#124070)
1 parent 1a8f49f commit 8c138be

File tree

4 files changed

+45
-4
lines changed

4 files changed

+45
-4
lines changed

flang/include/flang/Runtime/CUDA/pointer.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ int RTDECL(CUFPointerAllocate)(Descriptor &, int64_t stream = -1,
2121
bool hasStat = false, const Descriptor *errMsg = nullptr,
2222
const char *sourceFile = nullptr, int sourceLine = 0);
2323

24+
/// Perform allocation of the descriptor without synchronization. Assign data
25+
/// from source.
///
/// \param pointer    Descriptor to allocate; receives the data from \p source.
/// \param source     Descriptor whose contents are assigned into \p pointer
///                   once the allocation has succeeded.
/// \param stream     Forwarded to CUFPointerAllocate (default -1).
/// \param hasStat    Forwarded to CUFPointerAllocate.
/// \param errMsg     Forwarded to CUFPointerAllocate.
/// \param sourceFile Source location used for allocation and for the
///                   terminator guarding the assignment.
/// \param sourceLine See \p sourceFile.
/// \return The status of the allocation step.
///
/// Declared with RTDECL, like the other declarations in this header; RTDEF is
/// used for the matching definition in pointer.cpp.
26+
int RTDECL(CUFPointerAllocateSource)(Descriptor &pointer,
27+
const Descriptor &source, int64_t stream = -1, bool hasStat = false,
28+
const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
29+
int sourceLine = 0);
30+
2431
} // extern "C"
2532

2633
} // namespace Fortran::runtime::cuda

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,12 @@ struct CUFAllocateOpConversion
189189

190190
mlir::func::FuncOp func;
191191
if (op.getSource()) {
192-
if (isPointer)
193-
TODO(loc, "pointer allocation with source");
194192
func =
195-
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocateSource)>(
196-
loc, builder);
193+
isPointer
194+
? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSource)>(
195+
loc, builder)
196+
: fir::runtime::getRuntimeFunc<mkRTKey(
197+
CUFAllocatableAllocateSource)>(loc, builder);
197198
} else {
198199
func =
199200
isPointer

flang/runtime/CUDA/pointer.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "flang/Runtime/CUDA/pointer.h"
10+
#include "../assign-impl.h"
1011
#include "../stat.h"
1112
#include "../terminator.h"
13+
#include "flang/Runtime/CUDA/memmove-function.h"
1214
#include "flang/Runtime/pointer.h"
1315

1416
#include "cuda_runtime.h"
@@ -33,6 +35,19 @@ int RTDEF(CUFPointerAllocate)(Descriptor &desc, int64_t stream, bool hasStat,
3335
return stat;
3436
}
3537

38+
// Allocate `pointer` through CUFPointerAllocate and, if that succeeds, assign
// the contents of `source` into it. Returns the status produced by the
// allocation step; the assignment itself does not modify that status.
int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer,
39+
const Descriptor &source, int64_t stream, bool hasStat,
40+
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
41+
// Delegate the allocation; every argument except `source` is forwarded as-is.
int stat{RTNAME(CUFPointerAllocate)(
42+
pointer, stream, hasStat, errMsg, sourceFile, sourceLine)};
43+
// Copy from `source` only when the allocation reported success.
if (stat == StatOk) {
44+
// Terminator carries the caller's source location for the assignment call.
Terminator terminator{sourceFile, sourceLine};
45+
// MemmoveHostToDevice is supplied as the copy routine.
// NOTE(review): presumably this assumes `source` data lives in host memory —
// confirm against flang/Runtime/CUDA/memmove-function.h.
Fortran::runtime::DoFromSourceAssign(
46+
pointer, source, terminator, &MemmoveHostToDevice);
47+
}
48+
return stat;
49+
}
50+
3651
RT_EXT_API_GROUP_END
3752

3853
} // extern "C"

flang/test/Fir/CUDA/cuda-allocate.fir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,4 +192,22 @@ func.func @_QPp_alloc() {
192192
// CHECK-LABEL: func.func @_QPp_alloc()
193193
// CHECK: fir.call @_FortranACUFPointerAllocate
194194

195+
// Test input: a cuf.allocate carrying a source operand, applied to a descriptor
// whose fir.declare has the pointer attribute and the CUDA device data
// attribute. The FileCheck directives following this function verify which
// runtime entry point the conversion selects.
func.func @_QPpointer_source() {
196+
// The five constants below are not referenced by any operation in this function.
%c0_i64 = arith.constant 0 : i64
197+
%c1_i32 = arith.constant 1 : i32
198+
%c0_i32 = arith.constant 0 : i32
199+
%c1 = arith.constant 1 : index
200+
%c0 = arith.constant 0 : index
201+
// "a": box created with fir.alloca and declared with the allocatable
// attribute; it is loaded below and used as the source operand.
%0 = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xf32>>> {bindc_name = "a", uniq_name = "_QFpointer_sourceEa"}
202+
%4 = fir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFpointer_sourceEa"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
203+
// "a_d": box created with cuf.alloc (device data attribute) and declared with
// the pointer attribute; it is the object being allocated.
%5 = cuf.alloc !fir.box<!fir.heap<!fir.array<?x?xf32>>> {bindc_name = "a_d", data_attr = #cuf.cuda<device>, uniq_name = "_QFpointer_sourceEa_d"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
204+
%7 = fir.declare %5 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFpointer_sourceEa_d"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
205+
%8 = fir.load %4 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
206+
// Operation under test: allocate %7 using %8 as the source.
%22 = cuf.allocate %7 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>> source(%8 : !fir.box<!fir.heap<!fir.array<?x?xf32>>>) {data_attr = #cuf.cuda<device>} -> i32
207+
return
208+
}
209+
210+
// CHECK-LABEL: func.func @_QPpointer_source()
211+
// CHECK: _FortranACUFPointerAllocateSource
212+
195213
} // end of module

0 commit comments

Comments
 (0)