Skip to content

Commit dbe480d

Browse files
Fixes to address TODOs and review comments
-> Handle the case where the rank of one of the matrices is 1. -> Add a test for the above case with logical type. -> Remove the return value from the matmul call.
1 parent 019b009 commit dbe480d

File tree

4 files changed

+48
-18
lines changed

4 files changed

+48
-18
lines changed

flang/include/flang/Lower/TransformationalRuntime.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ void genCshiftVector(FirOpBuilder &builder, mlir::Location loc,
2626
mlir::Value resultBox, mlir::Value arrayBox,
2727
mlir::Value shiftBox);
2828

29-
mlir::Value genMatmul(Fortran::lower::FirOpBuilder &builder, mlir::Location loc,
30-
mlir::Value matrixABox, mlir::Value matrixBBox,
31-
mlir::Value resultBox);
29+
void genMatmul(Fortran::lower::FirOpBuilder &builder, mlir::Location loc,
30+
mlir::Value matrixABox, mlir::Value matrixBBox,
31+
mlir::Value resultBox);
3232

3333
void genReshape(FirOpBuilder &builder, mlir::Location loc,
3434
mlir::Value resultBox, mlir::Value sourceBox,

flang/lib/Lower/IntrinsicCall.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2443,12 +2443,11 @@ IntrinsicLibrary::genMatmul(mlir::Type resultType,
24432443
mlir::Value matrixA = fir::getBase(matrixTmpA);
24442444
fir::BoxValue matrixTmpB = builder.createBox(loc, args[1]);
24452445
mlir::Value matrixB = fir::getBase(matrixTmpB);
2446+
unsigned resultRank =
2447+
(matrixTmpA.rank() == 1 || matrixTmpB.rank() == 1) ? 1 : 2;
24462448

24472449
// Create mutable fir.box to be passed to the runtime for the result.
2448-
// TODO: The result can also be of rank 1 if one of the input matrices
2449-
// is of rank 1. Will this info be available at compile time or should
2450-
// code be generated to compute the rank?
2451-
auto resultArrayType = builder.getVarLenSeqTy(resultType, 2);
2450+
auto resultArrayType = builder.getVarLenSeqTy(resultType, resultRank);
24522451
auto resultMutableBox =
24532452
Fortran::lower::createTempMutableBox(builder, loc, resultArrayType);
24542453
auto resultIrBox =

flang/lib/Lower/TransformationalRuntime.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,23 +54,18 @@ void Fortran::lower::genCshiftVector(Fortran::lower::FirOpBuilder &builder,
5454
}
5555

5656
/// Generate call to Matmul intrinsic runtime routine.
57-
mlir::Value Fortran::lower::genMatmul(Fortran::lower::FirOpBuilder &builder,
58-
mlir::Location loc, mlir::Value resultBox,
59-
mlir::Value matrixABox,
60-
mlir::Value matrixBBox) {
61-
mlir::FuncOp func;
62-
auto ty = matrixABox.getType();
63-
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
64-
65-
func = Fortran::lower::getRuntimeFunc<mkRTKey(Matmul)>(loc, builder);
57+
void Fortran::lower::genMatmul(Fortran::lower::FirOpBuilder &builder,
58+
mlir::Location loc, mlir::Value resultBox,
59+
mlir::Value matrixABox, mlir::Value matrixBBox) {
60+
auto func = Fortran::lower::getRuntimeFunc<mkRTKey(Matmul)>(loc, builder);
6661
auto fTy = func.getType();
6762
auto sourceFile = Fortran::lower::locationToFilename(builder, loc);
6863
auto sourceLine =
6964
Fortran::lower::locationToLineNo(builder, loc, fTy.getInput(4));
7065
auto args =
7166
Fortran::lower::createArguments(builder, loc, fTy, resultBox, matrixABox,
7267
matrixBBox, sourceFile, sourceLine);
73-
return builder.create<fir::CallOp>(loc, func, args).getResult(0);
68+
builder.create<fir::CallOp>(loc, func, args);
7469
}
7570

7671
/// Generate call to Reshape intrinsic runtime routine.

flang/test/Lower/intrinsic-procedures.f90

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -908,12 +908,13 @@ subroutine lge_test
908908

909909
! MATMUL
910910
! CHECK-LABEL: matmul_test
911-
! CHECK-SAME: (%[[X:.*]]: !fir.ref<!fir.array<3x1xf32>>, %[[Y:.*]]: !fir.ref<!fir.array<1x3xf32>>, {{.*}}: !fir.ref<!fir.array<2x2xf32>>)
911+
! CHECK-SAME: (%[[X:.*]]: !fir.ref<!fir.array<3x1xf32>>, %[[Y:.*]]: !fir.ref<!fir.array<1x3xf32>>, %[[Z:.*]]: !fir.ref<!fir.array<2x2xf32>>)
912912
! CHECK: %[[RESULT_BOX_ADDR:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xf32>>> {uniq_name = ""}
913913
! CHECK: %[[C3:.*]] = constant 3 : index
914914
! CHECK: %[[C1:.*]] = constant 1 : index
915915
! CHECK: %[[C1_0:.*]] = constant 1 : index
916916
! CHECK: %[[C3_1:.*]] = constant 3 : index
917+
! CHECK: %[[Z_BOX:.*]] = fir.array_load %[[Z]]({{.*}}) : (!fir.ref<!fir.array<2x2xf32>>, !fir.shape<2>) -> !fir.array<2x2xf32>
917918
! CHECK: %[[X_SHAPE:.*]] = fir.shape %[[C3]], %[[C1]] : (index, index) -> !fir.shape<2>
918919
! CHECK: %[[X_BOX:.*]] = fir.embox %[[X]](%[[X_SHAPE]]) : (!fir.ref<!fir.array<3x1xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<3x1xf32>>
919920
! CHECK: %[[Y_SHAPE:.*]] = fir.shape %[[C1_0]], %[[C3_1]] : (index, index) -> !fir.shape<2>
@@ -929,12 +930,47 @@ subroutine lge_test
929930
! CHECK: {{.*}}fir.call @_FortranAMatmul(%[[RESULT_BOX_ADDR_RUNTIME]], %[[X_BOX_RUNTIME]], %[[Y_BOX_RUNTIME]], {{.*}}, {{.*}} : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
930931
! CHECK: %[[RESULT_BOX:.*]] = fir.load %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
931932
! CHECK: %[[RESULT_TMP:.*]] = fir.box_addr %[[RESULT_BOX]] : (!fir.box<!fir.heap<!fir.array<?x?xf32>>>) -> !fir.heap<!fir.array<?x?xf32>>
933+
! CHECK: %[[Z_COPY_FROM_RESULT:.*]] = fir.do_loop
934+
! CHECK: {{.*}}fir.array_fetch
935+
! CHECK: {{.*}}fir.array_update
936+
! CHECK: fir.result
937+
! CHECK: }
938+
! CHECK: fir.array_merge_store %[[Z_BOX]], %[[Z_COPY_FROM_RESULT]] to %[[Z]] : !fir.array<2x2xf32>, !fir.array<2x2xf32>, !fir.ref<!fir.array<2x2xf32>>
932939
! CHECK: fir.freemem %[[RESULT_TMP]] : !fir.heap<!fir.array<?x?xf32>>
933940
subroutine matmul_test(x,y,z)
934941
real :: x(3,1), y(1,3), z(2,2)
935942
z = matmul(x,y)
936943
end subroutine
937944

945+
! CHECK-LABEL: matmul_test2
946+
! CHECK-SAME: (%[[X_BOX:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>, %[[Y_BOX:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>, %[[Z_BOX:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>)
947+
!CHECK: %[[RESULT_BOX_ADDR:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> {uniq_name = ""}
948+
!CHECK: %[[Z:.*]] = fir.array_load %[[Z_BOX]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.array<?x!fir.logical<4>>
949+
!CHECK: %[[ZERO_INIT:.*]] = fir.zero_bits !fir.heap<!fir.array<?x!fir.logical<4>>>
950+
!CHECK: %[[C0:.*]] = constant 0 : index
951+
!CHECK: %[[RESULT_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1>
952+
!CHECK: %[[RESULT_BOX:.*]] = fir.embox %[[ZERO_INIT]](%[[RESULT_SHAPE]]) : (!fir.heap<!fir.array<?x!fir.logical<4>>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
953+
!CHECK: fir.store %[[RESULT_BOX]] to %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
954+
!CHECK: %[[RESULT_BOX_RUNTIME:.*]] = fir.convert %[[RESULT_BOX_ADDR]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>) -> !fir.ref<!fir.box<none>>
955+
!CHECK: %[[X_BOX_RUNTIME:.*]] = fir.convert %[[X_BOX]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.box<none>
956+
!CHECK: %[[Y_BOX_RUNTIME:.*]] = fir.convert %[[Y_BOX]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
957+
!CHECK: {{.*}}fir.call @_FortranAMatmul(%[[RESULT_BOX_RUNTIME]], %[[X_BOX_RUNTIME]], %[[Y_BOX_RUNTIME]], {{.*}}, {{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
958+
!CHECK: %[[RESULT_BOX:.*]] = fir.load %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
959+
!CHECK: %[[RESULT_TMP:.*]] = fir.box_addr %[[RESULT_BOX]] : (!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>) -> !fir.heap<!fir.array<?x!fir.logical<4>>>
960+
!CHECK: %[[Z_COPY_FROM_RESULT:.*]] = fir.do_loop
961+
!CHECK: {{.*}}fir.array_fetch
962+
!CHECK: {{.*}}fir.array_update
963+
!CHECK: fir.result
964+
!CHECK: }
965+
!CHECK: fir.array_merge_store %[[Z]], %[[Z_COPY_FROM_RESULT]] to %[[Z_BOX]] : !fir.array<?x!fir.logical<4>>, !fir.array<?x!fir.logical<4>>, !fir.box<!fir.array<?x!fir.logical<4>>>
966+
!CHECK: fir.freemem %[[RESULT_TMP]] : !fir.heap<!fir.array<?x!fir.logical<4>>>
967+
subroutine matmul_test2(X, Y, Z)
968+
logical :: X(:,:)
969+
logical :: Y(:)
970+
logical :: Z(:)
971+
Z = matmul(X, Y)
972+
end subroutine
973+
938974
! MAXLOC
939975
! CHECK-LABEL: maxloc_test
940976
! CHECK-SAME: %[[arg0:.*]]: !fir.box<!fir.array<?xi32>>,

0 commit comments

Comments
 (0)