Skip to content

Commit cea71dc

Browse files
authored
Merge pull request #931 from kiranchandramohan/matmul
Lowering for the Matmul Intrinsic
2 parents b209722 + dbe480d commit cea71dc

File tree

4 files changed

+119
-1
lines changed

4 files changed

+119
-1
lines changed

flang/include/flang/Lower/TransformationalRuntime.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ void genCshiftVector(FirOpBuilder &builder, mlir::Location loc,
2626
mlir::Value resultBox, mlir::Value arrayBox,
2727
mlir::Value shiftBox);
2828

29+
void genMatmul(Fortran::lower::FirOpBuilder &builder, mlir::Location loc,
30+
mlir::Value matrixABox, mlir::Value matrixBBox,
31+
mlir::Value resultBox);
32+
2933
void genReshape(FirOpBuilder &builder, mlir::Location loc,
3034
mlir::Value resultBox, mlir::Value sourceBox,
3135
mlir::Value shapeBox, mlir::Value padBox, mlir::Value orderBox);

flang/lib/Lower/IntrinsicCall.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ struct IntrinsicLibrary {
482482
mlir::Value genIshftc(mlir::Type, llvm::ArrayRef<mlir::Value>);
483483
fir::ExtendedValue genLen(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
484484
fir::ExtendedValue genLenTrim(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
485+
fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
485486
fir::ExtendedValue genMaxloc(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
486487
fir::ExtendedValue genMaxval(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
487488
mlir::Value genMerge(mlir::Type, llvm::ArrayRef<mlir::Value>);
@@ -703,6 +704,10 @@ static constexpr IntrinsicHandler handlers[]{
703704
{"lgt", &I::genCharacterCompare<mlir::CmpIPredicate::sgt>},
704705
{"lle", &I::genCharacterCompare<mlir::CmpIPredicate::sle>},
705706
{"llt", &I::genCharacterCompare<mlir::CmpIPredicate::slt>},
707+
{"matmul",
708+
&I::genMatmul,
709+
{{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
710+
/*isElemental=*/false},
706711
{"max", &I::genExtremum<Extremum::Max, ExtremumBehavior::MinMaxss>},
707712
{"maxloc",
708713
&I::genMaxloc,
@@ -2427,6 +2432,34 @@ IntrinsicLibrary::genCharacterCompare(mlir::Type type,
24272432
fir::getBase(args[1]), fir::getLen(args[1]));
24282433
}
24292434

2435+
// MATMUL
2436+
fir::ExtendedValue
2437+
IntrinsicLibrary::genMatmul(mlir::Type resultType,
2438+
llvm::ArrayRef<fir::ExtendedValue> args) {
2439+
assert(args.size() == 2);
2440+
2441+
// Handle required matmul arguments
2442+
fir::BoxValue matrixTmpA = builder.createBox(loc, args[0]);
2443+
mlir::Value matrixA = fir::getBase(matrixTmpA);
2444+
fir::BoxValue matrixTmpB = builder.createBox(loc, args[1]);
2445+
mlir::Value matrixB = fir::getBase(matrixTmpB);
2446+
unsigned resultRank =
2447+
(matrixTmpA.rank() == 1 || matrixTmpB.rank() == 1) ? 1 : 2;
2448+
2449+
// Create mutable fir.box to be passed to the runtime for the result.
2450+
auto resultArrayType = builder.getVarLenSeqTy(resultType, resultRank);
2451+
auto resultMutableBox =
2452+
Fortran::lower::createTempMutableBox(builder, loc, resultArrayType);
2453+
auto resultIrBox =
2454+
Fortran::lower::getMutableIRBox(builder, loc, resultMutableBox);
2455+
// Call runtime. The runtime is allocating the result.
2456+
Fortran::lower::genMatmul(builder, loc, resultIrBox, matrixA, matrixB);
2457+
// Read result from mutable fir.box and add it to the list of temps to be
2458+
// finalized by the StatementContext.
2459+
return readAndAddCleanUp(resultMutableBox, resultType,
2460+
"unexpected result for MATMUL");
2461+
}
2462+
24302463
// MERGE
24312464
mlir::Value IntrinsicLibrary::genMerge(mlir::Type,
24322465
llvm::ArrayRef<mlir::Value> args) {
@@ -2653,7 +2686,7 @@ IntrinsicLibrary::genReshape(mlir::Type resultType,
26532686
// Handle shape argument
26542687
auto shape = builder.createBox(loc, args[1]);
26552688
fir::BoxValue shapeTmp = shape;
2656-
auto shapeRank = shapeTmp.rank();
2689+
[[maybe_unused]] auto shapeRank = shapeTmp.rank();
26572690
assert(shapeRank == 1);
26582691
auto shapeTy = shape.getType();
26592692
auto shapeArrTy = fir::dyn_cast_ptrOrBoxEleTy(shapeTy);

flang/lib/Lower/TransformationalRuntime.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "flang/Lower/TransformationalRuntime.h"
10+
#include "../../runtime/matmul.h"
1011
#include "../../runtime/transformational.h"
1112
#include "RTBuilder.h"
1213
#include "flang/Lower/Bridge.h"
@@ -52,6 +53,21 @@ void Fortran::lower::genCshiftVector(Fortran::lower::FirOpBuilder &builder,
5253
builder.create<fir::CallOp>(loc, cshiftFunc, args);
5354
}
5455

56+
/// Generate call to Matmul intrinsic runtime routine.
57+
void Fortran::lower::genMatmul(Fortran::lower::FirOpBuilder &builder,
58+
mlir::Location loc, mlir::Value resultBox,
59+
mlir::Value matrixABox, mlir::Value matrixBBox) {
60+
auto func = Fortran::lower::getRuntimeFunc<mkRTKey(Matmul)>(loc, builder);
61+
auto fTy = func.getType();
62+
auto sourceFile = Fortran::lower::locationToFilename(builder, loc);
63+
auto sourceLine =
64+
Fortran::lower::locationToLineNo(builder, loc, fTy.getInput(4));
65+
auto args =
66+
Fortran::lower::createArguments(builder, loc, fTy, resultBox, matrixABox,
67+
matrixBBox, sourceFile, sourceLine);
68+
builder.create<fir::CallOp>(loc, func, args);
69+
}
70+
5571
/// Generate call to Reshape intrinsic runtime routine.
5672
void Fortran::lower::genReshape(Fortran::lower::FirOpBuilder &builder,
5773
mlir::Location loc, mlir::Value resultBox,

flang/test/Lower/intrinsic-procedures.f90

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,71 @@ subroutine lge_test
906906
print*, llt(c1, c2)
907907
end
908908

909+
! MATMUL
910+
! CHECK-LABEL: matmul_test
911+
! CHECK-SAME: (%[[X:.*]]: !fir.ref<!fir.array<3x1xf32>>, %[[Y:.*]]: !fir.ref<!fir.array<1x3xf32>>, %[[Z:.*]]: !fir.ref<!fir.array<2x2xf32>>)
912+
! CHECK: %[[RESULT_BOX_ADDR:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xf32>>> {uniq_name = ""}
913+
! CHECK: %[[C3:.*]] = constant 3 : index
914+
! CHECK: %[[C1:.*]] = constant 1 : index
915+
! CHECK: %[[C1_0:.*]] = constant 1 : index
916+
! CHECK: %[[C3_1:.*]] = constant 3 : index
917+
! CHECK: %[[Z_BOX:.*]] = fir.array_load %[[Z]]({{.*}}) : (!fir.ref<!fir.array<2x2xf32>>, !fir.shape<2>) -> !fir.array<2x2xf32>
918+
! CHECK: %[[X_SHAPE:.*]] = fir.shape %[[C3]], %[[C1]] : (index, index) -> !fir.shape<2>
919+
! CHECK: %[[X_BOX:.*]] = fir.embox %[[X]](%[[X_SHAPE]]) : (!fir.ref<!fir.array<3x1xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<3x1xf32>>
920+
! CHECK: %[[Y_SHAPE:.*]] = fir.shape %[[C1_0]], %[[C3_1]] : (index, index) -> !fir.shape<2>
921+
! CHECK: %[[Y_BOX:.*]] = fir.embox %[[Y]](%[[Y_SHAPE]]) : (!fir.ref<!fir.array<1x3xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<1x3xf32>>
922+
! CHECK: %[[ZERO_INIT:.*]] = fir.zero_bits !fir.heap<!fir.array<?x?xf32>>
923+
! CHECK: %[[C0:.*]] = constant 0 : index
924+
! CHECK: %[[RESULT_SHAPE:.*]] = fir.shape %[[C0]], %[[C0]] : (index, index) -> !fir.shape<2>
925+
! CHECK: %[[RESULT_BOX_VAL:.*]] = fir.embox %[[ZERO_INIT]](%[[RESULT_SHAPE]]) : (!fir.heap<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.heap<!fir.array<?x?xf32>>>
926+
! CHECK: fir.store %[[RESULT_BOX_VAL]] to %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
927+
! CHECK: %[[RESULT_BOX_ADDR_RUNTIME:.*]] = fir.convert %[[RESULT_BOX_ADDR]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
928+
! CHECK: %[[X_BOX_RUNTIME:.*]] = fir.convert %[[X_BOX]] : (!fir.box<!fir.array<3x1xf32>>) -> !fir.box<none>
929+
! CHECK: %[[Y_BOX_RUNTIME:.*]] = fir.convert %[[Y_BOX]] : (!fir.box<!fir.array<1x3xf32>>) -> !fir.box<none>
930+
! CHECK: {{.*}}fir.call @_FortranAMatmul(%[[RESULT_BOX_ADDR_RUNTIME]], %[[X_BOX_RUNTIME]], %[[Y_BOX_RUNTIME]], {{.*}}, {{.*}} : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
931+
! CHECK: %[[RESULT_BOX:.*]] = fir.load %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
932+
! CHECK: %[[RESULT_TMP:.*]] = fir.box_addr %[[RESULT_BOX]] : (!fir.box<!fir.heap<!fir.array<?x?xf32>>>) -> !fir.heap<!fir.array<?x?xf32>>
933+
! CHECK: %[[Z_COPY_FROM_RESULT:.*]] = fir.do_loop
934+
! CHECK: {{.*}}fir.array_fetch
935+
! CHECK: {{.*}}fir.array_update
936+
! CHECK: fir.result
937+
! CHECK: }
938+
! CHECK: fir.array_merge_store %[[Z_BOX]], %[[Z_COPY_FROM_RESULT]] to %[[Z]] : !fir.array<2x2xf32>, !fir.array<2x2xf32>, !fir.ref<!fir.array<2x2xf32>>
939+
! CHECK: fir.freemem %[[RESULT_TMP]] : !fir.heap<!fir.array<?x?xf32>>
940+
subroutine matmul_test(x,y,z)
941+
real :: x(3,1), y(1,3), z(2,2)
942+
z = matmul(x,y)
943+
end subroutine
944+
945+
! CHECK-LABEL: matmul_test2
946+
! CHECK-SAME: (%[[X_BOX:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>, %[[Y_BOX:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>, %[[Z_BOX:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>)
947+
!CHECK: %[[RESULT_BOX_ADDR:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> {uniq_name = ""}
948+
!CHECK: %[[Z:.*]] = fir.array_load %[[Z_BOX]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.array<?x!fir.logical<4>>
949+
!CHECK: %[[ZERO_INIT:.*]] = fir.zero_bits !fir.heap<!fir.array<?x!fir.logical<4>>>
950+
!CHECK: %[[C0:.*]] = constant 0 : index
951+
!CHECK: %[[RESULT_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1>
952+
!CHECK: %[[RESULT_BOX:.*]] = fir.embox %[[ZERO_INIT]](%[[RESULT_SHAPE]]) : (!fir.heap<!fir.array<?x!fir.logical<4>>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
953+
!CHECK: fir.store %[[RESULT_BOX]] to %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
954+
!CHECK: %[[RESULT_BOX_RUNTIME:.*]] = fir.convert %[[RESULT_BOX_ADDR]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>) -> !fir.ref<!fir.box<none>>
955+
!CHECK: %[[X_BOX_RUNTIME:.*]] = fir.convert %[[X_BOX]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.box<none>
956+
!CHECK: %[[Y_BOX_RUNTIME:.*]] = fir.convert %[[Y_BOX]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
957+
!CHECK: {{.*}}fir.call @_FortranAMatmul(%[[RESULT_BOX_RUNTIME]], %[[X_BOX_RUNTIME]], %[[Y_BOX_RUNTIME]], {{.*}}, {{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
958+
!CHECK: %[[RESULT_BOX:.*]] = fir.load %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
959+
!CHECK: %[[RESULT_TMP:.*]] = fir.box_addr %[[RESULT_BOX]] : (!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>) -> !fir.heap<!fir.array<?x!fir.logical<4>>>
960+
!CHECK: %[[Z_COPY_FROM_RESULT:.*]] = fir.do_loop
961+
!CHECK: {{.*}}fir.array_fetch
962+
!CHECK: {{.*}}fir.array_update
963+
!CHECK: fir.result
964+
!CHECK: }
965+
!CHECK: fir.array_merge_store %[[Z]], %[[Z_COPY_FROM_RESULT]] to %[[Z_BOX]] : !fir.array<?x!fir.logical<4>>, !fir.array<?x!fir.logical<4>>, !fir.box<!fir.array<?x!fir.logical<4>>>
966+
!CHECK: fir.freemem %[[RESULT_TMP]] : !fir.heap<!fir.array<?x!fir.logical<4>>>
967+
subroutine matmul_test2(X, Y, Z)
968+
logical :: X(:,:)
969+
logical :: Y(:)
970+
logical :: Z(:)
971+
Z = matmul(X, Y)
972+
end subroutine
973+
909974
! MAXLOC
910975
! CHECK-LABEL: maxloc_test
911976
! CHECK-SAME: %[[arg0:.*]]: !fir.box<!fir.array<?xi32>>,

0 commit comments

Comments
 (0)