@@ -4251,15 +4251,37 @@ class FirConverter : public Fortran::lower::AbstractConverter {
4251
4251
bool lhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs (assign.lhs );
4252
4252
bool rhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs (assign.rhs );
4253
4253
4254
- auto getRefIfLoaded = [](mlir::Value val) -> mlir::Value {
4254
+ auto getRefFromValue = [](mlir::Value val) -> mlir::Value {
4255
4255
if (auto loadOp =
4256
4256
mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp ()))
4257
4257
return loadOp.getMemref ();
4258
+ if (!mlir::isa<fir::BaseBoxType>(val.getType ()))
4259
+ return val;
4260
+ if (auto declOp =
4261
+ mlir::dyn_cast_or_null<hlfir::DeclareOp>(val.getDefiningOp ())) {
4262
+ if (!declOp.getShape ())
4263
+ return val;
4264
+ if (mlir::isa<fir::ReferenceType>(declOp.getMemref ().getType ()))
4265
+ return declOp.getMemref ();
4266
+ }
4258
4267
return val;
4259
4268
};
4260
4269
4261
- mlir::Value rhsVal = getRefIfLoaded (rhs.getBase ());
4262
- mlir::Value lhsVal = getRefIfLoaded (lhs.getBase ());
4270
+ auto getShapeFromDecl = [](mlir::Value val) -> mlir::Value {
4271
+ if (!mlir::isa<fir::BaseBoxType>(val.getType ()))
4272
+ return {};
4273
+ if (auto declOp =
4274
+ mlir::dyn_cast_or_null<hlfir::DeclareOp>(val.getDefiningOp ()))
4275
+ return declOp.getShape ();
4276
+ return {};
4277
+ };
4278
+
4279
+ mlir::Value rhsVal = getRefFromValue (rhs.getBase ());
4280
+ mlir::Value lhsVal = getRefFromValue (lhs.getBase ());
4281
+ // Get shape from the rhs if available otherwise get it from lhs.
4282
+ mlir::Value shape = getShapeFromDecl (rhs.getBase ());
4283
+ if (!shape)
4284
+ shape = getShapeFromDecl (lhs.getBase ());
4263
4285
4264
4286
// device = host
4265
4287
if (lhsIsDevice && !rhsIsDevice) {
@@ -4272,19 +4294,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
4272
4294
base = convertOp.getValue ();
4273
4295
// Special case if the rhs is a constant.
4274
4296
if (matchPattern (base.getDefiningOp (), mlir::m_Constant ())) {
4275
- builder.create <cuf::DataTransferOp>(
4276
- loc, base, lhsVal, /* shape= */ mlir::Value{}, transferKindAttr);
4297
+ builder.create <cuf::DataTransferOp>(loc, base, lhsVal, shape,
4298
+ transferKindAttr);
4277
4299
} else {
4278
4300
auto associate = hlfir::genAssociateExpr (
4279
4301
loc, builder, rhs, rhs.getType (), " .cuf_host_tmp" );
4280
4302
builder.create <cuf::DataTransferOp>(loc, associate.getBase (), lhsVal,
4281
- /* shape=*/ mlir::Value{},
4282
- transferKindAttr);
4303
+ shape, transferKindAttr);
4283
4304
builder.create <hlfir::EndAssociateOp>(loc, associate);
4284
4305
}
4285
4306
} else {
4286
- builder.create <cuf::DataTransferOp>(
4287
- loc, rhsVal, lhsVal, /* shape= */ mlir::Value{}, transferKindAttr);
4307
+ builder.create <cuf::DataTransferOp>(loc, rhsVal, lhsVal, shape,
4308
+ transferKindAttr);
4288
4309
}
4289
4310
return ;
4290
4311
}
@@ -4293,8 +4314,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
4293
4314
if (!lhsIsDevice && rhsIsDevice) {
4294
4315
auto transferKindAttr = cuf::DataTransferKindAttr::get (
4295
4316
builder.getContext (), cuf::DataTransferKind::DeviceHost);
4296
- builder.create <cuf::DataTransferOp>(loc, rhsVal, lhsVal,
4297
- /* shape=*/ mlir::Value{},
4317
+ builder.create <cuf::DataTransferOp>(loc, rhsVal, lhsVal, shape,
4298
4318
transferKindAttr);
4299
4319
return ;
4300
4320
}
@@ -4304,8 +4324,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
4304
4324
assert (rhs.isVariable () && " CUDA Fortran assignment rhs is not legal" );
4305
4325
auto transferKindAttr = cuf::DataTransferKindAttr::get (
4306
4326
builder.getContext (), cuf::DataTransferKind::DeviceDevice);
4307
- builder.create <cuf::DataTransferOp>(loc, rhsVal, lhsVal,
4308
- /* shape=*/ mlir::Value{},
4327
+ builder.create <cuf::DataTransferOp>(loc, rhsVal, lhsVal, shape,
4309
4328
transferKindAttr);
4310
4329
return ;
4311
4330
}
0 commit comments