Skip to content

Commit dae7298

Browse files
Tai78641Jerry-Ge
authored andcommitted
[mlir][tosa] Add FP8 lit tests
Add FP8 lit tests to the following operators: ARGMAX AVGPOOL CONV2D CONV3D DEPTHWISE_CONV2D MATMUL MAX_POOL2D TRANSPOSE_CONV2D CONST CAST CONCAT PAD RESHAPE REVERSE SLICE TILE TRANSPOSE GATHER SCATTER Signed-off-by: Tai Ly <[email protected]> Signed-off-by: Jerry Ge <[email protected]> Change-Id: I56adfabb2396e38b7ed3479e4fd680b740bdb4e4
1 parent d1bd1c7 commit dae7298

File tree

2 files changed

+289
-8
lines changed

2 files changed

+289
-8
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -522,14 +522,7 @@ LogicalResult tosa::AvgPool2dOp::verify() {
522522
if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
523523
return failure();
524524

525-
if ((inputETy.isF32() && resultETy.isF32()) ||
526-
(inputETy.isF16() && resultETy.isF16()) ||
527-
(inputETy.isBF16() && resultETy.isBF16()) ||
528-
(inputETy.isInteger(8) && resultETy.isInteger(8)) ||
529-
(inputETy.isInteger(16) && resultETy.isInteger(16)))
530-
return success();
531-
532-
return emitOpError("input/output element types are incompatible.");
525+
return success();
533526
}
534527

535528
LogicalResult tosa::ClampOp::verify() {

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,3 +783,291 @@ func.func @test_const_shape() -> !tosa.shape<4> {
783783
%cst = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
784784
return %cst : !tosa.shape<4>
785785
}
786+
787+
// F8 support tests
788+
789+
// -----
790+
// CHECK-LABEL: argmax_f8E5M2
791+
func.func @test_argmax_f8E5M2(%arg0: tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32> {
792+
%0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32>
793+
return %0 : tensor<12x16xi32>
794+
}
795+
796+
// -----
797+
// CHECK-LABEL: avg_pool2d_f8E5M2
798+
func.func @test_avg_pool2d_f8E5M2(%arg0: tensor<1x7x7x9xf8E5M2>) -> tensor<1x7x7x9xf8E5M2> {
799+
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
800+
%output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
801+
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E5M2>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x7x7x9xf8E5M2>
802+
return %0 : tensor<1x7x7x9xf8E5M2>
803+
}
804+
805+
// -----
806+
// CHECK-LABEL: conv2d_f8E5M2
807+
func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
808+
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
809+
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
810+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
811+
return %0 : tensor<1x4x4x8xf16>
812+
}
813+
814+
// -----
815+
// CHECK-LABEL: conv3d_f8E5M2
816+
func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16> {
817+
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
818+
return %0 : tensor<1x4x8x21x34xf16>
819+
}
820+
821+
// -----
822+
// CHECK-LABEL: depthwise_conv2d_f8E5M2
823+
func.func @test_depthwise_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16> {
824+
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
825+
return %0 : tensor<1x4x4x8xf16>
826+
}
827+
828+
// -----
829+
// CHECK-LABEL: test_matmul_f8E5M2
830+
func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
831+
%0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16>
832+
return %0 : tensor<1x14x28xf16>
833+
}
834+
835+
// -----
836+
// CHECK-LABEL: max_pool2d_f8E5M2
837+
func.func @test_max_pool2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2> {
838+
%0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2>
839+
return %0 : tensor<1x32x32x8xf8E5M2>
840+
}
841+
842+
// -----
843+
844+
// CHECK-LABEL: transpose_conv2d_f8E5M2
845+
func.func @test_transpose_conv2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16> {
846+
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16>
847+
return %0 : tensor<1x32x32x16xf16>
848+
}
849+
850+
// -----
851+
// CHECK-LABEL: const_f8E5M2
852+
func.func @test_const_f8E5M2(%arg0 : index) -> tensor<4xf8E5M2> {
853+
%0 = "tosa.const"() {values = dense<[3.0, -0.0, -1.0, 2.0]> : tensor<4xf8E5M2>} : () -> tensor<4xf8E5M2>
854+
return %0 : tensor<4xf8E5M2>
855+
}
856+
857+
// -----
858+
// CHECK-LABEL: cast_f8E5M2
859+
func.func @test_cast_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16> {
860+
%0 = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16>
861+
return %0 : tensor<13x21x3xf16>
862+
}
863+
864+
// -----
865+
// CHECK-LABEL: concat_f8E5M2
866+
func.func @test_concat_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x21x3xf8E5M2>) -> tensor<26x21x3xf8E5M2> {
867+
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf8E5M2>, tensor<13x21x3xf8E5M2>) -> tensor<26x21x3xf8E5M2>
868+
return %0 : tensor<26x21x3xf8E5M2>
869+
}
870+
871+
// -----
872+
// CHECK-LABEL: pad_f8E5M2
873+
func.func @test_pad_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
874+
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
875+
%cst = "tosa.const"() { values = dense<-0.0> : tensor<1xf8E5M2> } : () -> tensor<1xf8E5M2>
876+
%0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E5M2>, !tosa.shape<6>, tensor<1xf8E5M2>) -> tensor<13x21x3xf8E5M2>
877+
return %0 : tensor<13x21x3xf8E5M2>
878+
}
879+
880+
// -----
881+
// CHECK-LABEL: reshape_f8E5M2
882+
func.func @test_reshape_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<1x819xf8E5M2> {
883+
%1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
884+
%0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf8E5M2>, !tosa.shape<2>) -> tensor<1x819xf8E5M2>
885+
return %0 : tensor<1x819xf8E5M2>
886+
}
887+
888+
// -----
889+
// CHECK-LABEL: reverse_f8E5M2
890+
func.func @test_reverse_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
891+
%0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
892+
return %0 : tensor<13x21x3xf8E5M2>
893+
}
894+
895+
// -----
896+
// CHECK-LABEL: slice_f8E5M2
897+
func.func @test_slice_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<4x11x1xf8E5M2> {
898+
%0 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
899+
%1 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
900+
%2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf8E5M2>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf8E5M2>
901+
return %2 : tensor<4x11x1xf8E5M2>
902+
}
903+
904+
// -----
905+
// CHECK-LABEL: tile_f8E5M2
906+
func.func @test_tile_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<39x21x6xf8E5M2> {
907+
%cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
908+
%0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf8E5M2>, !tosa.shape<3>) -> tensor<39x21x6xf8E5M2>
909+
return %0 : tensor<39x21x6xf8E5M2>
910+
}
911+
912+
// -----
913+
func.func @test_transpose_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<3x13x21xf8E5M2> {
914+
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf8E5M2>) -> tensor<3x13x21xf8E5M2>
915+
return %1 : tensor<3x13x21xf8E5M2>
916+
}
917+
918+
// -----
919+
// CHECK-LABEL: gather_f8E5M2
920+
func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf8E5M2> {
921+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>) -> tensor<13x26x3xf8E5M2>
922+
return %0 : tensor<13x26x3xf8E5M2>
923+
}
924+
925+
// -----
926+
// CHECK-LABEL: scatter_f8E5M2
927+
func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
928+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
929+
return %0 : tensor<13x21x3xf8E5M2>
930+
}
931+
932+
// -----
933+
// CHECK-LABEL: argmax_f8E4M3FN
934+
func.func @test_argmax_f8E4M3FN(%arg0: tensor<12x8x16xf8E4M3FN>) -> tensor<12x16xi32> {
935+
%0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E4M3FN>) -> tensor<12x16xi32>
936+
return %0 : tensor<12x16xi32>
937+
}
938+
939+
// -----
940+
// CHECK-LABEL: avg_pool2d_f8E4M3FN
941+
func.func @test_avg_pool2d_f8E4M3FN(%arg0: tensor<1x7x7x9xf8E4M3FN>) -> tensor<1x7x7x9xf8E4M3FN> {
942+
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
943+
%output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
944+
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x7x7x9xf8E4M3FN>
945+
return %0 : tensor<1x7x7x9xf8E4M3FN>
946+
}
947+
948+
// -----
949+
// CHECK-LABEL: conv2d_f8E4M3FN
950+
func.func @test_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<8x1x1x4xf8E4M3FN>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
951+
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
952+
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
953+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E4M3FN>, tensor<8x1x1x4xf8E4M3FN>, tensor<8xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x4x8xf16>
954+
return %0 : tensor<1x4x4x8xf16>
955+
}
956+
957+
// -----
958+
// CHECK-LABEL: conv3d_f8E4M3FN
959+
func.func @test_conv3d_f8E4M3FN(%arg0: tensor<1x4x8x21x17xf8E4M3FN>, %arg1: tensor<34x1x1x1x17xf8E4M3FN>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16> {
960+
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E4M3FN>, tensor<34x1x1x1x17xf8E4M3FN>, tensor<34xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16>
961+
return %0 : tensor<1x4x8x21x34xf16>
962+
}
963+
964+
// -----
965+
// CHECK-LABEL: depthwise_conv2d_f8E4M3FN
966+
func.func @test_depthwise_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<1x1x4x2xf8E4M3FN>, %arg2: tensor<8xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x4x4x8xf16> {
967+
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E4M3FN>, tensor<1x1x4x2xf8E4M3FN>, tensor<8xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x4x8xf16>
968+
return %0 : tensor<1x4x4x8xf16>
969+
}
970+
971+
// -----
972+
// CHECK-LABEL: matmul_f8E4M3FN
973+
func.func @test_matmul_f8E4M3FN(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16> {
974+
%0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16>
975+
return %0 : tensor<1x14x28xf16>
976+
}
977+
978+
// -----
979+
// CHECK-LABEL: max_pool2d_f8E4M3FN
980+
func.func @test_max_pool2d_f8E4M3FN(%arg0: tensor<1x32x32x8xf8E4M3FN>) -> tensor<1x32x32x8xf8E4M3FN> {
981+
%0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E4M3FN>) -> tensor<1x32x32x8xf8E4M3FN>
982+
return %0 : tensor<1x32x32x8xf8E4M3FN>
983+
}
984+
985+
// -----
986+
// CHECK-LABEL: transpose_conv2d_f8E4M3FN
987+
func.func @test_transpose_conv2d_f8E4M3FN(%arg0: tensor<1x32x32x8xf8E4M3FN>, %arg1: tensor<16x1x1x8xf8E4M3FN>, %arg2: tensor<16xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x32x32x16xf16> {
988+
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E4M3FN>, tensor<16x1x1x8xf8E4M3FN>, tensor<16xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x32x32x16xf16>
989+
return %0 : tensor<1x32x32x16xf16>
990+
}
991+
992+
// -----
993+
// CHECK-LABEL: const_f8E4M3FN
994+
func.func @test_const_f8E4M3FN(%arg0 : index) -> tensor<4xf8E4M3FN> {
995+
%0 = "tosa.const"() {values = dense<[3.0, -0.0, -1.0, 2.0]> : tensor<4xf8E4M3FN>} : () -> tensor<4xf8E4M3FN>
996+
return %0 : tensor<4xf8E4M3FN>
997+
}
998+
999+
// -----
1000+
// CHECK-LABEL: cast_f8E4M3FN
1001+
func.func @test_cast_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16> {
1002+
%0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
1003+
return %0 : tensor<13x21x3xf16>
1004+
}
1005+
1006+
// -----
1007+
// CHECK-LABEL: concat_f8E4M3FN
1008+
func.func @test_concat_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x21x3xf8E4M3FN>) -> tensor<26x21x3xf8E4M3FN> {
1009+
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf8E4M3FN>, tensor<13x21x3xf8E4M3FN>) -> tensor<26x21x3xf8E4M3FN>
1010+
return %0 : tensor<26x21x3xf8E4M3FN>
1011+
}
1012+
1013+
// -----
1014+
// CHECK-LABEL: pad_f8E4M3FN
1015+
func.func @test_pad_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
1016+
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
1017+
%cst = "tosa.const"() { values = dense<-0.0> : tensor<1xf8E4M3FN> } : () -> tensor<1xf8E4M3FN>
1018+
%0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<6>, tensor<1xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
1019+
return %0 : tensor<13x21x3xf8E4M3FN>
1020+
}
1021+
1022+
// -----
1023+
// CHECK-LABEL: reshape_f8E4M3FN
1024+
func.func @test_reshape_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<1x819xf8E4M3FN> {
1025+
%1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
1026+
%0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<2>) -> tensor<1x819xf8E4M3FN>
1027+
return %0 : tensor<1x819xf8E4M3FN>
1028+
}
1029+
1030+
// -----
1031+
// CHECK-LABEL: reverse_f8E4M3FN
1032+
func.func @test_reverse_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
1033+
%0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
1034+
return %0 : tensor<13x21x3xf8E4M3FN>
1035+
}
1036+
1037+
// -----
1038+
// CHECK-LABEL: slice_f8E4M3FN
1039+
func.func @test_slice_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<4x11x1xf8E4M3FN> {
1040+
%0 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
1041+
%1 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
1042+
%2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf8E4M3FN>
1043+
return %2 : tensor<4x11x1xf8E4M3FN>
1044+
}
1045+
1046+
// -----
1047+
// CHECK-LABEL: tile_f8E4M3FN
1048+
func.func @test_tile_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<39x21x6xf8E4M3FN> {
1049+
%cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
1050+
%0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf8E4M3FN>, !tosa.shape<3>) -> tensor<39x21x6xf8E4M3FN>
1051+
return %0 : tensor<39x21x6xf8E4M3FN>
1052+
}
1053+
1054+
// -----
1055+
// CHECK-LABEL: transpose_f8E4M3FN
1056+
func.func @test_transpose_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<3x13x21xf8E4M3FN> {
1057+
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf8E4M3FN>) -> tensor<3x13x21xf8E4M3FN>
1058+
return %1 : tensor<3x13x21xf8E4M3FN>
1059+
}
1060+
1061+
// -----
1062+
// CHECK-LABEL: gather_f8E4M3FN
1063+
func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf8E4M3FN> {
1064+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>) -> tensor<13x26x3xf8E4M3FN>
1065+
return %0 : tensor<13x26x3xf8E4M3FN>
1066+
}
1067+
1068+
// -----
1069+
// CHECK-LABEL: scatter_f8E4M3FN
1070+
func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
1071+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
1072+
return %0 : tensor<13x21x3xf8E4M3FN>
1073+
}

0 commit comments

Comments
 (0)