Skip to content

Commit 900cd62

Browse files
authored
[flang][cuda] Simplify data transfer when possible (#106120)
When possible, avoid using descriptors and instead pass the reference and the shape directly to the cuf.data_transfer operation.
1 parent 384d69f commit 900cd62

File tree

3 files changed

+52
-16
lines changed

3 files changed

+52
-16
lines changed

flang/lib/Lower/Bridge.cpp

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4251,15 +4251,37 @@ class FirConverter : public Fortran::lower::AbstractConverter {
42514251
bool lhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs(assign.lhs);
42524252
bool rhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs(assign.rhs);
42534253

4254-
auto getRefIfLoaded = [](mlir::Value val) -> mlir::Value {
4254+
auto getRefFromValue = [](mlir::Value val) -> mlir::Value {
42554255
if (auto loadOp =
42564256
mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp()))
42574257
return loadOp.getMemref();
4258+
if (!mlir::isa<fir::BaseBoxType>(val.getType()))
4259+
return val;
4260+
if (auto declOp =
4261+
mlir::dyn_cast_or_null<hlfir::DeclareOp>(val.getDefiningOp())) {
4262+
if (!declOp.getShape())
4263+
return val;
4264+
if (mlir::isa<fir::ReferenceType>(declOp.getMemref().getType()))
4265+
return declOp.getMemref();
4266+
}
42584267
return val;
42594268
};
42604269

4261-
mlir::Value rhsVal = getRefIfLoaded(rhs.getBase());
4262-
mlir::Value lhsVal = getRefIfLoaded(lhs.getBase());
4270+
auto getShapeFromDecl = [](mlir::Value val) -> mlir::Value {
4271+
if (!mlir::isa<fir::BaseBoxType>(val.getType()))
4272+
return {};
4273+
if (auto declOp =
4274+
mlir::dyn_cast_or_null<hlfir::DeclareOp>(val.getDefiningOp()))
4275+
return declOp.getShape();
4276+
return {};
4277+
};
4278+
4279+
mlir::Value rhsVal = getRefFromValue(rhs.getBase());
4280+
mlir::Value lhsVal = getRefFromValue(lhs.getBase());
4281+
// Get the shape from the rhs if available; otherwise, get it from the lhs.
4282+
mlir::Value shape = getShapeFromDecl(rhs.getBase());
4283+
if (!shape)
4284+
shape = getShapeFromDecl(lhs.getBase());
42634285

42644286
// device = host
42654287
if (lhsIsDevice && !rhsIsDevice) {
@@ -4272,19 +4294,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
42724294
base = convertOp.getValue();
42734295
// Special case if the rhs is a constant.
42744296
if (matchPattern(base.getDefiningOp(), mlir::m_Constant())) {
4275-
builder.create<cuf::DataTransferOp>(
4276-
loc, base, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr);
4297+
builder.create<cuf::DataTransferOp>(loc, base, lhsVal, shape,
4298+
transferKindAttr);
42774299
} else {
42784300
auto associate = hlfir::genAssociateExpr(
42794301
loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
42804302
builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
4281-
/*shape=*/mlir::Value{},
4282-
transferKindAttr);
4303+
shape, transferKindAttr);
42834304
builder.create<hlfir::EndAssociateOp>(loc, associate);
42844305
}
42854306
} else {
4286-
builder.create<cuf::DataTransferOp>(
4287-
loc, rhsVal, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr);
4307+
builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal, shape,
4308+
transferKindAttr);
42884309
}
42894310
return;
42904311
}
@@ -4293,8 +4314,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
42934314
if (!lhsIsDevice && rhsIsDevice) {
42944315
auto transferKindAttr = cuf::DataTransferKindAttr::get(
42954316
builder.getContext(), cuf::DataTransferKind::DeviceHost);
4296-
builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
4297-
/*shape=*/mlir::Value{},
4317+
builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal, shape,
42984318
transferKindAttr);
42994319
return;
43004320
}
@@ -4304,8 +4324,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
43044324
assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal");
43054325
auto transferKindAttr = cuf::DataTransferKindAttr::get(
43064326
builder.getContext(), cuf::DataTransferKind::DeviceDevice);
4307-
builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
4308-
/*shape=*/mlir::Value{},
4327+
builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal, shape,
43094328
transferKindAttr);
43104329
return;
43114330
}

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,11 @@ llvm::LogicalResult cuf::DataTransferOp::verify() {
112112
if (fir::isa_trivial(srcTy) &&
113113
matchPattern(getSrc().getDefiningOp(), mlir::m_Constant()))
114114
return mlir::success();
115+
115116
return emitOpError()
116117
<< "expect src and dst to be references or descriptors or src to "
117-
"be a constant";
118+
"be a constant: "
119+
<< srcTy << " - " << dstTy;
118120
}
119121

120122
//===----------------------------------------------------------------------===//

flang/test/Lower/CUDA/cuda-data-transfer.cuf

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ contains
1111
function dev1(a)
1212
integer, device :: a(:)
1313
integer :: dev1
14+
dev1 = 1
1415
end function
1516
end
1617

@@ -198,8 +199,8 @@ end subroutine
198199
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "b"}, %[[ARG2:.*]]: !fir.ref<i32> {fir.bindc_name = "n"})
199200
! CHECK: %[[B:.*]]:2 = hlfir.declare %[[ARG1]](%{{.*}}) dummy_scope %{{.*}} {uniq_name = "_QFsub8Eb"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
200201
! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]](%{{.*}}) dummy_scope %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub8Ea"} : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xi32>>, !fir.ref<!fir.array<?xi32>>)
201-
! CHECK: cuf.data_transfer %[[A]]#0 to %[[B]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.box<!fir.array<?xi32>>, !fir.ref<!fir.array<10xi32>>
202-
! CHECK: cuf.data_transfer %[[B]]#0 to %[[A]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.box<!fir.array<?xi32>>
202+
! CHECK: cuf.data_transfer %[[ARG0]] to %[[B]]#0, %{{.*}} : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<?xi32>>, !fir.ref<!fir.array<10xi32>>
203+
! CHECK: cuf.data_transfer %[[B]]#0 to %[[ARG0]], %{{.*}} : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<?xi32>>
203204

204205
subroutine sub9(a)
205206
integer, pinned, allocatable :: a(:)
@@ -274,3 +275,17 @@ end subroutine
274275
! CHECK-LABEL: func.func @_QPsub14()
275276
! CHECK: %[[TRUE:.*]] = arith.constant true
276277
! CHECK: cuf.data_transfer %[[TRUE]] to %{{.*}}#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i1, !fir.ref<!fir.array<10x!fir.logical<4>>>
278+
279+
subroutine sub15(a_dev, a_host, n, m)
280+
integer, intent(in) :: n, m
281+
real, device :: a_dev(n*m)
282+
real :: a_host(n*m)
283+
284+
a_dev = a_host
285+
end subroutine
286+
287+
! CHECK-LABEL: func.func @_QPsub15(
288+
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xf32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a_dev"}, %[[ARG1:.*]]: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a_host"}
289+
! CHECK: %{{.*}} = fir.shape %{{.*}} : (index) -> !fir.shape<1>
290+
! CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
291+
! CHECK: cuf.data_transfer %[[ARG1]] to %[[ARG0]], %[[SHAPE]] : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>

0 commit comments

Comments
 (0)