24 | 24 | /// indirectly via a parent object.
25 | 25 | //===----------------------------------------------------------------------===//
26 | 26 |
| 27 | +#include "flang/Lower/DirectivesCommon.h" |
27 | 28 | #include "flang/Optimizer/Builder/FIRBuilder.h"
| 29 | +#include "flang/Optimizer/Builder/HLFIRTools.h" |
28 | 30 | #include "flang/Optimizer/Dialect/FIRType.h"
29 | 31 | #include "flang/Optimizer/Dialect/Support/KindMapping.h"
| 32 | +#include "flang/Optimizer/HLFIR/HLFIROps.h" |
30 | 33 | #include "flang/Optimizer/OpenMP/Passes.h"
| 34 | +#include "mlir/Analysis/SliceAnalysis.h" |
31 | 35 | #include "mlir/Dialect/Func/IR/FuncOps.h"
32 | 36 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
33 | 37 | #include "mlir/IR/BuiltinDialect.h"
@@ -486,6 +490,160 @@ class MapInfoFinalizationPass
486 | 490 | // iterations from previous function scopes.
487 | 491 | localBoxAllocas.clear();
488 | 492 |
| 493 | + // First, walk `omp.map.info` ops to see if any record members should be |
| 494 | + // implicitly mapped. |
| 495 | + func->walk([&](mlir::omp::MapInfoOp op) { |
| 496 | + mlir::Type underlyingType = |
| 497 | + fir::unwrapRefType(op.getVarPtr().getType()); |
| 498 | + |
| 499 | + // TODO Test with and support more complicated cases, such as arrays |
| 500 | + // of records. |
| 501 | + if (!fir::isRecordWithAllocatableMember(underlyingType)) |
| 502 | + return mlir::WalkResult::advance(); |
| 503 | + |
| 504 | + // TODO For now, only consider `omp.target` ops. Other ops that support |
| 505 | + // `map` clauses will follow later. |
| 506 | + mlir::omp::TargetOp target = |
| 507 | + mlir::dyn_cast_if_present<mlir::omp::TargetOp>( |
| 508 | + getFirstTargetUser(op)); |
| 509 | + |
| 510 | + if (!target) |
| 511 | + return mlir::WalkResult::advance(); |
| 512 | + |
| 513 | + auto mapClauseOwner = |
| 514 | + llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target); |
| 515 | + |
| 516 | + int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op); |
| 517 | + assert(mapVarIdx >= 0 && |
| 518 | + mapVarIdx < |
| 519 | + static_cast<int64_t>(mapClauseOwner.getMapVars().size())); |
| 520 | + |
| 521 | + auto argIface = |
| 522 | + llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target); |
| 523 | + // TODO How should `map` block arguments that correspond to `private`, |
| 524 | + // `use_device_addr`, and `use_device_ptr` be handled? |
| 525 | + mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx]; |
| 526 | + llvm::SetVector<mlir::Operation *> mapVarForwardSlice; |
| 527 | + mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice); |
| 528 | + |
| 529 | + mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) { |
| 530 | + // TODO Support coordinate_of ops. |
| 531 | + // |
| 532 | + // TODO Support call ops by recursively examining the forward slice of |
| 533 | + // the corresponding parameter to the field in the called function. |
| 534 | + return !mlir::isa<hlfir::DesignateOp>(sliceOp); |
| 535 | + }); |
| 536 | + |
| 537 | + auto recordType = mlir::cast<fir::RecordType>(underlyingType); |
| 538 | + llvm::SmallVector<mlir::Value> newMapOpsForFields; |
| 539 | + llvm::SmallVector<int64_t> fieldIndices; |
| 540 | + |
| 541 | + for (auto fieldMemTyPair : recordType.getTypeList()) { |
| 542 | + auto &field = fieldMemTyPair.first; |
| 543 | + auto memTy = fieldMemTyPair.second; |
| 544 | + |
| 545 | + bool shouldMapField = |
| 546 | + llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) { |
| 547 | + if (!fir::isAllocatableType(memTy)) |
| 548 | + return false; |
| 549 | + |
| 550 | + auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp); |
| 551 | + if (!designateOp) |
| 552 | + return false; |
| 553 | + |
| 554 | + return designateOp.getComponent() && |
| 555 | + designateOp.getComponent()->strref() == field; |
| 556 | + }) != mapVarForwardSlice.end(); |
| 557 | + |
| 558 | + // TODO Handle recursive record types. Adapting |
| 559 | + // `createParentSymAndGenIntermediateMaps` to work directly on MLIR |
| 560 | + // entities might be helpful here. |
| 561 | + |
| 562 | + if (!shouldMapField) |
| 563 | + continue; |
| 564 | + |
| 565 | + int64_t fieldIdx = recordType.getFieldIndex(field); |
| 566 | + bool alreadyMapped = [&]() { |
| 567 | + if (op.getMembersIndexAttr()) |
| 568 | + for (auto indexList : op.getMembersIndexAttr()) { |
| 569 | + auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList); |
| 570 | + if (indexListAttr.size() == 1 && |
| 571 | + mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() == |
| 572 | + fieldIdx) |
| 573 | + return true; |
| 574 | + } |
| 575 | + |
| 576 | + return false; |
| 577 | + }(); |
| 578 | + |
| 579 | + if (alreadyMapped) |
| 580 | + continue; |
| 581 | + |
| 582 | + builder.setInsertionPoint(op); |
| 583 | + mlir::Value fieldIdxVal = builder.createIntegerConstant( |
| 584 | + op.getLoc(), mlir::IndexType::get(builder.getContext()), |
| 585 | + fieldIdx); |
| 586 | + auto fieldCoord = builder.create<fir::CoordinateOp>( |
| 587 | + op.getLoc(), builder.getRefType(memTy), op.getVarPtr(), |
| 588 | + fieldIdxVal); |
| 589 | + Fortran::lower::AddrAndBoundsInfo info = |
| 590 | + Fortran::lower::getDataOperandBaseAddr( |
| 591 | + builder, fieldCoord, /*isOptional=*/false, op.getLoc()); |
| 592 | + llvm::SmallVector<mlir::Value> bounds = |
| 593 | + Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp, |
| 594 | + mlir::omp::MapBoundsType>( |
| 595 | + builder, info, |
| 596 | + hlfir::translateToExtendedValue(op.getLoc(), builder, |
| 597 | + hlfir::Entity{fieldCoord}) |
| 598 | + .first, |
| 599 | + /*dataExvIsAssumedSize=*/false, op.getLoc()); |
| 600 | + |
| 601 | + mlir::omp::MapInfoOp fieldMapOp = |
| 602 | + builder.create<mlir::omp::MapInfoOp>( |
| 603 | + op.getLoc(), fieldCoord.getResult().getType(), |
| 604 | + fieldCoord.getResult(), |
| 605 | + mlir::TypeAttr::get( |
| 606 | + fir::unwrapRefType(fieldCoord.getResult().getType())), |
| 607 | + /*varPtrPtr=*/mlir::Value{}, |
| 608 | + /*members=*/mlir::ValueRange{}, |
| 609 | + /*members_index=*/mlir::ArrayAttr{}, |
| 610 | + /*bounds=*/bounds, op.getMapTypeAttr(), |
| 611 | + builder.getAttr<mlir::omp::VariableCaptureKindAttr>( |
| 612 | + mlir::omp::VariableCaptureKind::ByRef), |
| 613 | + builder.getStringAttr(op.getNameAttr().strref() + "." + |
| 614 | + field + ".implicit_map"), |
| 615 | + /*partial_map=*/builder.getBoolAttr(false)); |
| 616 | + newMapOpsForFields.emplace_back(fieldMapOp); |
| 617 | + fieldIndices.emplace_back(fieldIdx); |
| 618 | + } |
| 619 | + |
| 620 | + if (newMapOpsForFields.empty()) |
| 621 | + return mlir::WalkResult::advance(); |
| 622 | + |
| 623 | + op.getMembersMutable().append(newMapOpsForFields); |
| 624 | + llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices; |
| 625 | + mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr(); |
| 626 | + |
| 627 | + if (oldMembersIdxAttr) |
| 628 | + for (mlir::Attribute indexList : oldMembersIdxAttr) { |
| 629 | + llvm::SmallVector<int64_t> listVec; |
| 630 | + |
| 631 | + for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList)) |
| 632 | + listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt()); |
| 633 | + |
| 634 | + newMemberIndices.emplace_back(std::move(listVec)); |
| 635 | + } |
| 636 | + |
| 637 | + for (int64_t newFieldIdx : fieldIndices) |
| 638 | + newMemberIndices.emplace_back( |
| 639 | + llvm::SmallVector<int64_t>(1, newFieldIdx)); |
| 640 | + |
| 641 | + op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices)); |
| 642 | + op.setPartialMap(true); |
| 643 | + |
| 644 | + return mlir::WalkResult::advance(); |
| 645 | + }); |
| 646 | + |
489 | 647 | func->walk([&](mlir::omp::MapInfoOp op) {
490 | 648 | // TODO: Currently only supports a single user for the MapInfoOp. This
491 | 649 | // is fine for the moment, as the Fortran frontend will generate a
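
To make the new walk easier to review, here is a minimal Fortran sketch of the scenario it handles; the program, type, and component names below are illustrative assumptions, not taken from an existing test. A derived-type variable with an allocatable component is mapped onto an `omp.target` region as a whole, and the component is referenced inside the region, so an `hlfir.designate` for that component appears in the forward slice of the corresponding map block argument.

```fortran
! Hypothetical example; record_t, rec, and alloc_comp are made-up names.
program implicit_component_map_sketch
  implicit none

  type record_t
    integer, allocatable :: alloc_comp(:)
  end type record_t

  type(record_t) :: rec

  allocate(rec%alloc_comp(10))
  rec%alloc_comp = 0

  ! `rec` is mapped explicitly; `rec%alloc_comp` is only referenced inside
  ! the region, so the pass should add an implicit child map for it.
  !$omp target map(tofrom: rec)
    rec%alloc_comp(1) = 1
  !$omp end target

  print *, rec%alloc_comp(1)
end program implicit_component_map_sketch
```

For a case like this, the walk builds a `fir.coordinate_of` into the component, wraps it in a new child `omp.map.info` with implicit bounds, appends it to the parent map's `members`, records the component's field index as a one-element list in `members_index`, and marks the parent as a partial map.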