@@ -783,3 +783,291 @@ func.func @test_const_shape() -> !tosa.shape<4> {
783
783
%cst = tosa.const_shape {values = dense <1 > : tensor <4 xindex >} : () -> !tosa.shape <4 >
784
784
return %cst : !tosa.shape <4 >
785
785
}
786
+
787
+ // F8 support tests
788
+
789
+ // -----
790
+ // CHECK-LABEL: argmax_f8E5M2
791
+ func.func @test_argmax_f8E5M2 (%arg0: tensor <12 x8 x16 xf8 E5 M2 >) -> tensor <12 x16 xi32 > {
792
+ %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor <12 x8 x16 xf8 E5 M2 >) -> tensor <12 x16 xi32 >
793
+ return %0 : tensor <12 x16 xi32 >
794
+ }
795
+
796
+ // -----
797
+ // CHECK-LABEL: avg_pool2d_f8E5M2
798
+ func.func @test_avg_pool2d_f8E5M2 (%arg0: tensor <1 x7 x7 x9 xf8 E5 M2 >) -> tensor <1 x7 x7 x9 xf8 E5 M2 > {
799
+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
800
+ %output_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
801
+ %0 = tosa.avg_pool2d %arg0 , %input_zp , %output_zp {acc_type = f16 , kernel = array<i64 : 2 , 2 >, pad = array<i64 : 0 , 1 , 0 , 1 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x7 x7 x9 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x7 x7 x9 xf8 E5 M2 >
802
+ return %0 : tensor <1 x7 x7 x9 xf8 E5 M2 >
803
+ }
804
+
805
+ // -----
806
+ // CHECK-LABEL: conv2d_f8E5M2
807
+ func.func @test_conv2d_f8E5M2 (%arg0: tensor <1 x4 x4 x4 xf8 E5 M2 >, %arg1: tensor <8 x1 x1 x4 xf8 E5 M2 >, %arg2: tensor <8 xf16 >) -> tensor <1 x4 x4 x8 xf16 > {
808
+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
809
+ %weight_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
810
+ %0 = tosa.conv2d %arg0 , %arg1 , %arg2 , %input_zp , %weight_zp {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, local_bound = true } : (tensor <1 x4 x4 x4 xf8 E5 M2 >, tensor <8 x1 x1 x4 xf8 E5 M2 >, tensor <8 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x4 x8 xf16 >
811
+ return %0 : tensor <1 x4 x4 x8 xf16 >
812
+ }
813
+
814
+ // -----
815
+ // CHECK-LABEL: conv3d_f8E5M2
816
+ func.func @test_conv3d_f8E5M2 (%arg0: tensor <1 x4 x8 x21 x17 xf8 E5 M2 >, %arg1: tensor <34 x1 x1 x1 x17 xf8 E5 M2 >, %arg2: tensor <34 xf16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x8 x21 x34 xf16 > {
817
+ %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf8 E5 M2 >, tensor <34 x1 x1 x1 x17 xf8 E5 M2 >, tensor <34 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x8 x21 x34 xf16 >
818
+ return %0 : tensor <1 x4 x8 x21 x34 xf16 >
819
+ }
820
+
821
+ // -----
822
+ // CHECK-LABEL: depthwise_conv2d_f8E5M2
823
+ func.func @test_depthwise_conv2d_f8E5M2 (%arg0: tensor <1 x4 x4 x4 xf8 E5 M2 >, %arg1: tensor <1 x1 x4 x2 xf8 E5 M2 >, %arg2: tensor <8 xf16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x4 x8 xf16 > {
824
+ %0 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x4 x4 x4 xf8 E5 M2 >, tensor <1 x1 x4 x2 xf8 E5 M2 >, tensor <8 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x4 x8 xf16 >
825
+ return %0 : tensor <1 x4 x4 x8 xf16 >
826
+ }
827
+
828
+ // -----
829
+ // CHECK-LABEL: test_matmul_f8E5M2
830
+ func.func @test_matmul_f8E5M2 (%arg0: tensor <1 x14 x19 xf8 E5 M2 >, %arg1: tensor <1 x19 x28 xf8 E5 M2 >) -> tensor <1 x14 x28 xf16 > {
831
+ %0 = tosa.matmul %arg0 , %arg1 : (tensor <1 x14 x19 xf8 E5 M2 >, tensor <1 x19 x28 xf8 E5 M2 >) -> tensor <1 x14 x28 xf16 >
832
+ return %0 : tensor <1 x14 x28 xf16 >
833
+ }
834
+
835
+ // -----
836
+ // CHECK-LABEL: max_pool2d_f8E5M2
837
+ func.func @test_max_pool2d_f8E5M2 (%arg0: tensor <1 x32 x32 x8 xf8 E5 M2 >) -> tensor <1 x32 x32 x8 xf8 E5 M2 > {
838
+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E5 M2 >) -> tensor <1 x32 x32 x8 xf8 E5 M2 >
839
+ return %0 : tensor <1 x32 x32 x8 xf8 E5 M2 >
840
+ }
841
+
842
+ // -----
843
+
844
+ // CHECK-LABEL: transpose_conv2d_f8E5M2
845
+ func.func @test_transpose_conv2d_f8E5M2 (%arg0: tensor <1 x32 x32 x8 xf8 E5 M2 >, %arg1: tensor <16 x1 x1 x8 xf8 E5 M2 >, %arg2: tensor <16 xf16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x32 x32 x16 xf16 > {
846
+ %0 = tosa.transpose_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , out_pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E5 M2 >, tensor <16 x1 x1 x8 xf8 E5 M2 >, tensor <16 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x32 x32 x16 xf16 >
847
+ return %0 : tensor <1 x32 x32 x16 xf16 >
848
+ }
849
+
850
+ // -----
851
+ // CHECK-LABEL: const_f8E5M2
852
+ func.func @test_const_f8E5M2 (%arg0 : index ) -> tensor <4 xf8 E5 M2 > {
853
+ %0 = " tosa.const" () {values = dense <[3.0 , -0.0 , -1.0 , 2.0 ]> : tensor <4 xf8 E5 M2 >} : () -> tensor <4 xf8 E5 M2 >
854
+ return %0 : tensor <4 xf8 E5 M2 >
855
+ }
856
+
857
+ // -----
858
+ // CHECK-LABEL: cast_f8E5M2
859
+ func.func @test_cast_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf16 > {
860
+ %0 = tosa.cast %arg0 : (tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf16 >
861
+ return %0 : tensor <13 x21 x3 xf16 >
862
+ }
863
+
864
+ // -----
865
+ // CHECK-LABEL: concat_f8E5M2
866
+ func.func @test_concat_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >, %arg1: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <26 x21 x3 xf8 E5 M2 > {
867
+ %0 = tosa.concat %arg0 , %arg1 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E5 M2 >, tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <26 x21 x3 xf8 E5 M2 >
868
+ return %0 : tensor <26 x21 x3 xf8 E5 M2 >
869
+ }
870
+
871
+ // -----
872
+ // CHECK-LABEL: pad_f8E5M2
873
+ func.func @test_pad_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 > {
874
+ %padding = tosa.const_shape {values = dense <0 > : tensor <6 xindex >} : () -> !tosa.shape <6 >
875
+ %cst = " tosa.const" () { values = dense <-0.0 > : tensor <1 xf8 E5 M2 > } : () -> tensor <1 xf8 E5 M2 >
876
+ %0 = tosa.pad %arg0 , %padding , %cst : (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <6 >, tensor <1 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 >
877
+ return %0 : tensor <13 x21 x3 xf8 E5 M2 >
878
+ }
879
+
880
+ // -----
881
+ // CHECK-LABEL: reshape_f8E5M2
882
+ func.func @test_reshape_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <1 x819 xf8 E5 M2 > {
883
+ %1 = tosa.const_shape {values = dense <[1 , 819 ]> : tensor <2 xindex >} : () -> !tosa.shape <2 >
884
+ %0 = tosa.reshape %arg0 , %1 : (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <2 >) -> tensor <1 x819 xf8 E5 M2 >
885
+ return %0 : tensor <1 x819 xf8 E5 M2 >
886
+ }
887
+
888
+ // -----
889
+ // CHECK-LABEL: reverse_f8E5M2
890
+ func.func @test_reverse_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 > {
891
+ %0 = tosa.reverse %arg0 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 >
892
+ return %0 : tensor <13 x21 x3 xf8 E5 M2 >
893
+ }
894
+
895
+ // -----
896
+ // CHECK-LABEL: slice_f8E5M2
897
+ func.func @test_slice_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <4 x11 x1 xf8 E5 M2 > {
898
+ %0 = tosa.const_shape {values = dense <[4 , 11 , 1 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
899
+ %1 = tosa.const_shape {values = dense <[6 , 8 , 0 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
900
+ %2 = tosa.slice %arg0 , %0 , %1 : (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <3 >, !tosa.shape <3 >) -> tensor <4 x11 x1 xf8 E5 M2 >
901
+ return %2 : tensor <4 x11 x1 xf8 E5 M2 >
902
+ }
903
+
904
+ // -----
905
+ // CHECK-LABEL: tile_f8E5M2
906
+ func.func @test_tile_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <39 x21 x6 xf8 E5 M2 > {
907
+ %cst = tosa.const_shape { values = dense <[3 , 1 , 2 ]> : tensor <3 xindex > } : () -> !tosa.shape <3 >
908
+ %0 = tosa.tile %arg0 , %cst: (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <3 >) -> tensor <39 x21 x6 xf8 E5 M2 >
909
+ return %0 : tensor <39 x21 x6 xf8 E5 M2 >
910
+ }
911
+
912
+ // -----
913
+ func.func @test_transpose_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <3 x13 x21 xf8 E5 M2 > {
914
+ %1 = tosa.transpose %arg0 {perms = array<i32 : 2 , 0 , 1 >} : (tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <3 x13 x21 xf8 E5 M2 >
915
+ return %1 : tensor <3 x13 x21 xf8 E5 M2 >
916
+ }
917
+
918
+ // -----
919
+ // CHECK-LABEL: gather_f8E5M2
920
+ func.func @test_gather_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >, %arg1: tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E5 M2 > {
921
+ %0 = tosa.gather %arg0 , %arg1 : (tensor <13 x21 x3 xf8 E5 M2 >, tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E5 M2 >
922
+ return %0 : tensor <13 x26 x3 xf8 E5 M2 >
923
+ }
924
+
925
+ // -----
926
+ // CHECK-LABEL: scatter_f8E5M2
927
+ func.func @test_scatter_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >, %arg1: tensor <13 x26 xi32 >, %arg2: tensor <13 x26 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 > {
928
+ %0 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <13 x21 x3 xf8 E5 M2 >, tensor <13 x26 xi32 >, tensor <13 x26 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 >
929
+ return %0 : tensor <13 x21 x3 xf8 E5 M2 >
930
+ }
931
+
932
+ // -----
933
+ // CHECK-LABEL: argmax_f8E4M3FN
934
+ func.func @test_argmax_f8E4M3FN (%arg0: tensor <12 x8 x16 xf8 E4 M3 FN>) -> tensor <12 x16 xi32 > {
935
+ %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor <12 x8 x16 xf8 E4 M3 FN>) -> tensor <12 x16 xi32 >
936
+ return %0 : tensor <12 x16 xi32 >
937
+ }
938
+
939
+ // -----
940
+ // CHECK-LABEL: avg_pool2d_f8E4M3FN
941
+ func.func @test_avg_pool2d_f8E4M3FN (%arg0: tensor <1 x7 x7 x9 xf8 E4 M3 FN>) -> tensor <1 x7 x7 x9 xf8 E4 M3 FN> {
942
+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
943
+ %output_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
944
+ %0 = tosa.avg_pool2d %arg0 , %input_zp , %output_zp {acc_type = f16 , kernel = array<i64 : 2 , 2 >, pad = array<i64 : 0 , 1 , 0 , 1 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x7 x7 x9 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x7 x7 x9 xf8 E4 M3 FN>
945
+ return %0 : tensor <1 x7 x7 x9 xf8 E4 M3 FN>
946
+ }
947
+
948
+ // -----
949
+ // CHECK-LABEL: conv2d_f8E4M3FN
950
+ func.func @test_conv2d_f8E4M3FN (%arg0: tensor <1 x4 x4 x4 xf8 E4 M3 FN>, %arg1: tensor <8 x1 x1 x4 xf8 E4 M3 FN>, %arg2: tensor <8 xf16 >) -> tensor <1 x4 x4 x8 xf16 > {
951
+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
952
+ %weight_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
953
+ %0 = tosa.conv2d %arg0 , %arg1 , %arg2 , %input_zp , %weight_zp {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, local_bound = true } : (tensor <1 x4 x4 x4 xf8 E4 M3 FN>, tensor <8 x1 x1 x4 xf8 E4 M3 FN>, tensor <8 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x4 x8 xf16 >
954
+ return %0 : tensor <1 x4 x4 x8 xf16 >
955
+ }
956
+
957
+ // -----
958
+ // CHECK-LABEL: conv3d_f8E4M3FN
959
+ func.func @test_conv3d_f8E4M3FN (%arg0: tensor <1 x4 x8 x21 x17 xf8 E4 M3 FN>, %arg1: tensor <34 x1 x1 x1 x17 xf8 E4 M3 FN>, %arg2: tensor <34 xf16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x8 x21 x34 xf16 > {
960
+ %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf8 E4 M3 FN>, tensor <34 x1 x1 x1 x17 xf8 E4 M3 FN>, tensor <34 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x8 x21 x34 xf16 >
961
+ return %0 : tensor <1 x4 x8 x21 x34 xf16 >
962
+ }
963
+
964
+ // -----
965
+ // CHECK-LABEL: depthwise_conv2d_f8E4M3FN
966
+ func.func @test_depthwise_conv2d_f8E4M3FN (%arg0: tensor <1 x4 x4 x4 xf8 E4 M3 FN>, %arg1: tensor <1 x1 x4 x2 xf8 E4 M3 FN>, %arg2: tensor <8 xf16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x4 x8 xf16 > {
967
+ %0 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x4 x4 x4 xf8 E4 M3 FN>, tensor <1 x1 x4 x2 xf8 E4 M3 FN>, tensor <8 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x4 x8 xf16 >
968
+ return %0 : tensor <1 x4 x4 x8 xf16 >
969
+ }
970
+
971
+ // -----
972
+ // CHECK-LABEL: matmul_f8E4M3FN
973
+ func.func @test_matmul_f8E4M3FN (%arg0: tensor <1 x14 x19 xf8 E4 M3 FN>, %arg1: tensor <1 x19 x28 xf8 E4 M3 FN>) -> tensor <1 x14 x28 xf16 > {
974
+ %0 = tosa.matmul %arg0 , %arg1 : (tensor <1 x14 x19 xf8 E4 M3 FN>, tensor <1 x19 x28 xf8 E4 M3 FN>) -> tensor <1 x14 x28 xf16 >
975
+ return %0 : tensor <1 x14 x28 xf16 >
976
+ }
977
+
978
+ // -----
979
+ // CHECK-LABEL: max_pool2d_f8E4M3FN
980
+ func.func @test_max_pool2d_f8E4M3FN (%arg0: tensor <1 x32 x32 x8 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x8 xf8 E4 M3 FN> {
981
+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x8 xf8 E4 M3 FN>
982
+ return %0 : tensor <1 x32 x32 x8 xf8 E4 M3 FN>
983
+ }
984
+
985
+ // -----
986
+ // CHECK-LABEL: transpose_conv2d_f8E4M3FN
987
+ func.func @test_transpose_conv2d_f8E4M3FN (%arg0: tensor <1 x32 x32 x8 xf8 E4 M3 FN>, %arg1: tensor <16 x1 x1 x8 xf8 E4 M3 FN>, %arg2: tensor <16 xf16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x16 xf16 > {
988
+ %0 = tosa.transpose_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , out_pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E4 M3 FN>, tensor <16 x1 x1 x8 xf8 E4 M3 FN>, tensor <16 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x16 xf16 >
989
+ return %0 : tensor <1 x32 x32 x16 xf16 >
990
+ }
991
+
992
+ // -----
993
+ // CHECK-LABEL: const_f8E4M3FN
994
+ func.func @test_const_f8E4M3FN (%arg0 : index ) -> tensor <4 xf8 E4 M3 FN> {
995
+ %0 = " tosa.const" () {values = dense <[3.0 , -0.0 , -1.0 , 2.0 ]> : tensor <4 xf8 E4 M3 FN>} : () -> tensor <4 xf8 E4 M3 FN>
996
+ return %0 : tensor <4 xf8 E4 M3 FN>
997
+ }
998
+
999
+ // -----
1000
+ // CHECK-LABEL: cast_f8E4M3FN
1001
+ func.func @test_cast_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf16 > {
1002
+ %0 = tosa.cast %arg0 : (tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf16 >
1003
+ return %0 : tensor <13 x21 x3 xf16 >
1004
+ }
1005
+
1006
+ // -----
1007
+ // CHECK-LABEL: concat_f8E4M3FN
1008
+ func.func @test_concat_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>, %arg1: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <26 x21 x3 xf8 E4 M3 FN> {
1009
+ %0 = tosa.concat %arg0 , %arg1 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E4 M3 FN>, tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <26 x21 x3 xf8 E4 M3 FN>
1010
+ return %0 : tensor <26 x21 x3 xf8 E4 M3 FN>
1011
+ }
1012
+
1013
+ // -----
1014
+ // CHECK-LABEL: pad_f8E4M3FN
1015
+ func.func @test_pad_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN> {
1016
+ %padding = tosa.const_shape {values = dense <0 > : tensor <6 xindex >} : () -> !tosa.shape <6 >
1017
+ %cst = " tosa.const" () { values = dense <-0.0 > : tensor <1 xf8 E4 M3 FN> } : () -> tensor <1 xf8 E4 M3 FN>
1018
+ %0 = tosa.pad %arg0 , %padding , %cst : (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <6 >, tensor <1 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN>
1019
+ return %0 : tensor <13 x21 x3 xf8 E4 M3 FN>
1020
+ }
1021
+
1022
+ // -----
1023
+ // CHECK-LABEL: reshape_f8E4M3FN
1024
+ func.func @test_reshape_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <1 x819 xf8 E4 M3 FN> {
1025
+ %1 = tosa.const_shape {values = dense <[1 , 819 ]> : tensor <2 xindex >} : () -> !tosa.shape <2 >
1026
+ %0 = tosa.reshape %arg0 , %1 : (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <2 >) -> tensor <1 x819 xf8 E4 M3 FN>
1027
+ return %0 : tensor <1 x819 xf8 E4 M3 FN>
1028
+ }
1029
+
1030
+ // -----
1031
+ // CHECK-LABEL: reverse_f8E4M3FN
1032
+ func.func @test_reverse_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN> {
1033
+ %0 = tosa.reverse %arg0 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN>
1034
+ return %0 : tensor <13 x21 x3 xf8 E4 M3 FN>
1035
+ }
1036
+
1037
+ // -----
1038
+ // CHECK-LABEL: slice_f8E4M3FN
1039
+ func.func @test_slice_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <4 x11 x1 xf8 E4 M3 FN> {
1040
+ %0 = tosa.const_shape {values = dense <[4 , 11 , 1 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
1041
+ %1 = tosa.const_shape {values = dense <[6 , 8 , 0 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
1042
+ %2 = tosa.slice %arg0 , %0 , %1 : (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <3 >, !tosa.shape <3 >) -> tensor <4 x11 x1 xf8 E4 M3 FN>
1043
+ return %2 : tensor <4 x11 x1 xf8 E4 M3 FN>
1044
+ }
1045
+
1046
+ // -----
1047
+ // CHECK-LABEL: tile_f8E4M3FN
1048
+ func.func @test_tile_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <39 x21 x6 xf8 E4 M3 FN> {
1049
+ %cst = tosa.const_shape { values = dense <[3 , 1 , 2 ]> : tensor <3 xindex > } : () -> !tosa.shape <3 >
1050
+ %0 = tosa.tile %arg0 , %cst: (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <3 >) -> tensor <39 x21 x6 xf8 E4 M3 FN>
1051
+ return %0 : tensor <39 x21 x6 xf8 E4 M3 FN>
1052
+ }
1053
+
1054
+ // -----
1055
+ // CHECK-LABEL: transpose_f8E4M3FN
1056
+ func.func @test_transpose_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <3 x13 x21 xf8 E4 M3 FN> {
1057
+ %1 = tosa.transpose %arg0 {perms = array<i32 : 2 , 0 , 1 >} : (tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <3 x13 x21 xf8 E4 M3 FN>
1058
+ return %1 : tensor <3 x13 x21 xf8 E4 M3 FN>
1059
+ }
1060
+
1061
+ // -----
1062
+ // CHECK-LABEL: gather_f8E4M3FN
1063
+ func.func @test_gather_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>, %arg1: tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E4 M3 FN> {
1064
+ %0 = tosa.gather %arg0 , %arg1 : (tensor <13 x21 x3 xf8 E4 M3 FN>, tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E4 M3 FN>
1065
+ return %0 : tensor <13 x26 x3 xf8 E4 M3 FN>
1066
+ }
1067
+
1068
+ // -----
1069
+ // CHECK-LABEL: scatter_f8E4M3FN
1070
+ func.func @test_scatter_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>, %arg1: tensor <13 x26 xi32 >, %arg2: tensor <13 x26 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN> {
1071
+ %0 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <13 x21 x3 xf8 E4 M3 FN>, tensor <13 x26 xi32 >, tensor <13 x26 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN>
1072
+ return %0 : tensor <13 x21 x3 xf8 E4 M3 FN>
1073
+ }
0 commit comments