[mlir][tosa] Add FP8 lit tests #127730

Merged: 1 commit, Mar 7, 2025

Jerry-Ge (Member) commented Feb 19, 2025

Add FP8 lit tests to the following operators:

ARGMAX
AVGPOOL
CONV2D
CONV3D
DEPTHWISE_CONV2D
MATMUL
MAX_POOL2D
TRANSPOSE_CONV2D
CONST
CAST
CONCAT
PAD
RESHAPE
REVERSE
SLICE
TILE
TRANSPOSE
GATHER
SCATTER
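
For readers unfamiliar with the two FP8 element types this patch admits (`F8E4M3FN` and `F8E5M2`, per the `Tosa_F8` definition in the diff), here is a minimal Python sketch of how one byte of each format decodes to a real value. The function names are hypothetical and the code is not part of MLIR. It assumes the usual layouts: E4M3FN has bias 7, no infinities, and a single NaN mantissa pattern per sign, while E5M2 has bias 15 and IEEE-style inf/NaN.

```python
import math

def decode_fp8(byte: int, exp_bits: int, man_bits: int, bias: int, finite_only: bool) -> float:
    """Decode one FP8 byte; `finite_only=True` applies the E4M3FN rules."""
    sign = -1.0 if (byte >> 7) & 1 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    exp_max = (1 << exp_bits) - 1
    if finite_only:
        # E4M3FN: no infinities; only all-ones exponent AND mantissa is NaN.
        if exp == exp_max and man == (1 << man_bits) - 1:
            return math.nan
    elif exp == exp_max:
        # IEEE-style: all-ones exponent is inf (mantissa 0) or NaN (otherwise).
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:
        # Subnormal: no implicit leading one.
        return sign * man * 2.0 ** (1 - bias - man_bits)
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

def decode_f8e4m3fn(byte: int) -> float:
    return decode_fp8(byte, exp_bits=4, man_bits=3, bias=7, finite_only=True)

def decode_f8e5m2(byte: int) -> float:
    return decode_fp8(byte, exp_bits=5, man_bits=2, bias=15, finite_only=False)
```

Under these assumptions the largest finite values are 448 for E4M3FN (`0x7E`) and 57344 for E5M2 (`0x7B`), which illustrates the precision-versus-range trade-off between the two formats.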

llvmbot (Member) commented Feb 19, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Jerry-Ge (Jerry-Ge)

Changes

Add FP8 support to the following TOSA operators:

ARGMAX
AVGPOOL
CONV2D
CONV3D
DEPTHWISE_CONV2D
MATMUL
MAX_POOL2D
TRANSPOSE_CONV2D
CONST
CAST
CONCAT
PAD
DIM
RESHAPE
REVERSE
SLICE
TILE
TRANSPOSE
GATHER
SCATTER

Also adds verifiers where needed to check input/output element types, and renames the inputs of transpose_conv2d and select to match the spec.
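
The verifier bodies themselves fall in the truncated part of the patch, so the following is only a Python sketch of the kind of element-type check described above, not the code in `TosaOps.cpp`. It assumes an op such as `reverse`, whose input and output are expected to share one element type. The function name, the error wording, and the `BASE_TYPES` subset are all hypothetical.

```python
from typing import Optional

# Element types admitted by the patch's Tosa_F8 definition (MLIR spelling).
FP8_TYPES = {"f8E4M3FN", "f8E5M2"}
# Illustrative subset of the pre-existing element types, not the full TOSA list.
BASE_TYPES = {"i8", "i16", "i32", "f16", "bf16", "f32"}

def verify_same_element_type(op_name: str, in_type: str, out_type: str) -> Optional[str]:
    """Return an error message, or None when the check passes."""
    allowed = BASE_TYPES | FP8_TYPES
    # First reject element types outside the (assumed) allowed set.
    for role, ty in (("input", in_type), ("output", out_type)):
        if ty not in allowed:
            return f"'{op_name}' op {role} element type {ty} is not supported"
    # Then require the input and output element types to agree.
    if in_type != out_type:
        return (f"'{op_name}' op expects input and output element types "
                f"to match, got {in_type} and {out_type}")
    return None
```

A call such as `verify_same_element_type("tosa.reverse", "f8E5M2", "f8E5M2")` passes, while mixing `f8E5M2` with `f8E4M3FN` reports a mismatch.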


Patch is 73.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127730.diff

13 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+110-71)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+27-4)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+7-7)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+320-8)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp (+3-3)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (+1-1)
  • (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+13-8)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+11)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+16-18)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+284-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index d11ba65a13736..8947f7a9bd9a1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -41,7 +41,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor: $input,
+    Tosa_Tensor_Extended: $input,
     I32Attr: $axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
@@ -73,7 +73,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D_Extended:$input,
+
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
@@ -83,7 +84,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
   );
 
   let results = (outs
-    Tosa_Tensor4D:$output
+    Tosa_Tensor4D_Extended:$output
   );
 
   let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
@@ -102,7 +103,7 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D_Extended:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
     Optional<Tosa_ScalarTensor>:$input_zp,
@@ -133,11 +134,12 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor5D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
-    Tosa_Tensor1D:$bias,
+    Tosa_Tensor5D_Extended:$input,
+    TensorRankOf<[Tosa_Weight], [5]>:$weight,
+    Tosa_Tensor1D_Extended:$bias,
     Optional<Tosa_ScalarTensor>:$input_zp,
     Optional<Tosa_ScalarTensor>:$weight_zp,
+
     Tosa_IntArrayAttr6:$pad,
     Tosa_IntArrayAttr3:$stride,
     Tosa_IntArrayAttr3:$dilation,
@@ -146,7 +148,7 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
   );
 
   let results = (outs
-    Tosa_Tensor5D:$output
+    Tosa_Tensor5D_Extended:$output
   );
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
@@ -165,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D_Extended:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
-    Tosa_Tensor1D:$bias,
+    Tosa_Tensor1D_Extended:$bias,
     Optional<Tosa_ScalarTensor>:$input_zp,
     Optional<Tosa_ScalarTensor>:$weight_zp,
+
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -178,7 +181,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
   );
 
   let results = (outs
-    Tosa_Tensor4D:$output
+    Tosa_Tensor4D_Extended:$output
   );
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
@@ -237,8 +240,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor3D:$a,
-    Tosa_Tensor3D:$b,
+    Tosa_Tensor3D_Extended:$a,
+    Tosa_Tensor3D_Extended:$b,
     OptionalAttr<I32Attr>:$a_zp,
     OptionalAttr<I32Attr>:$b_zp
   );
@@ -248,6 +251,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   );
 
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -264,7 +268,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D_Extended:$input,
 
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
@@ -273,10 +277,11 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
   );
 
   let results = (outs
-    Tosa_Tensor4D:$output
+    Tosa_Tensor4D_Extended:$output
   );
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -327,11 +332,12 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D_Extended:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
-    Tosa_Tensor1D:$bias,
+    Tosa_Tensor1D_Extended:$bias,
     Optional<Tosa_ScalarTensor>:$input_zp,
     Optional<Tosa_ScalarTensor>:$weight_zp,
+
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$out_shape,
@@ -1190,9 +1196,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   }];
 
   let arguments = (ins
-    Tosa_I1Tensor:$pred,
-    Tosa_Tensor:$on_true,
-    Tosa_Tensor:$on_false
+    Tosa_I1Tensor:$input1,
+    Tosa_Tensor:$input2,
+    Tosa_Tensor:$input3
   );
 
   let results = (outs
@@ -1200,9 +1206,10 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   );
   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat = [{
-    operands attr-dict `:` `(` type($pred) `,` type($on_true) `,` type($on_false)
+    operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
     `)` `->` type($output)
   }];
 }
@@ -1518,16 +1525,17 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
   }];
 
   let arguments = (ins
-    Variadic<Tosa_Tensor>:$input1,
+    Variadic<Tosa_Tensor_Extended>:$input1,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_Tensor_Extended:$output
   );
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1563,14 +1571,14 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
   }];
 
   let arguments = (ins
-    Tosa_RankedTensor:$input1,
+    Tosa_RankedTensor_Extended:$input1,
     Tosa_Shape:$padding,
-    Optional<Tosa_Rank0Tensor>:$pad_const,
+    Optional<Tosa_ScalarTensor_Extended>:$pad_const,
     OptionalAttr<I32Attr>:$input_zp
   );
 
   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_RankedTensor_Extended:$output
   );
 
   let builders = [Tosa_PadOpQuantInfoBuilder,
@@ -1597,12 +1605,12 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
   let hasVerifier = 1;
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_Tensor_Extended:$input1,
     Tosa_Shape:$shape
   );
 
   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_RankedTensor_Extended:$output
   );
 
   let extraClassDeclaration = [{
@@ -1629,12 +1637,12 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_Tensor_Extended:$input1,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_Tensor_Extended:$output
   );
 
   let hasFolder = 1;
@@ -1656,13 +1664,13 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_Tensor_Extended:$input1,
     Tosa_Shape:$start,
     Tosa_Shape:$size
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_Tensor_Extended:$output
   );
 
   let hasCanonicalizer = 1;
@@ -1681,11 +1689,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_Tensor_Extended:$input1,
     Tosa_Shape:$multiples);
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_Tensor_Extended:$output
   );
 
   let extraClassDeclaration = [{
@@ -1709,12 +1717,12 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
-    Tosa_Int32Tensor:$perms
+    Tosa_Tensor_Extended:$input1,
+    Tosa_Int32Or64Tensor:$perms
   );
 
   let results = (
-    outs Tosa_Tensor:$output
+    outs Tosa_Tensor_Extended:$output
   );
 
   let extraClassDeclaration = [{
@@ -1743,13 +1751,14 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor3D:$values,
-    TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
+    Tosa_Tensor3D_Extended:$values,
+    2DTensorOf<[Tosa_Int32]>:$indices
   );
 
   let results = (outs
-    Tosa_Tensor3D:$output
+    Tosa_Tensor3D_Extended:$output
   );
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1764,14 +1773,15 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor3D:$values_in,
-    TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
-    Tosa_Tensor3D:$input
+    Tosa_Tensor3D_Extended:$values_in,
+    2DTensorOf<[Tosa_Int32]>:$indices,
+    Tosa_Tensor3D_Extended:$input
   );
 
   let results = (outs
-    Tosa_Tensor3D:$values_out
+    Tosa_Tensor3D_Extended:$values_out
   );
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1828,37 +1838,66 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
 
     | Mode                     | Input   | Output  |
     |--------------------------|---------|---------|
-    | signed 8 to bool         | int8    | Boolean |
-    | signed 16 to bool        | int16   | Boolean |
-    | signed 32 to bool        | int32   | Boolean |
-    | bool to 8                | Boolean | int8    |
-    | bool to 16               | Boolean | int16   |
-    | bool to 32               | Boolean | int32   |
-    | signed 8 to signed 16    | int8    | int16   |
-    | signed 8 to signed 32    | int8    | int32   |
-    | signed 16 to signed 8    | int16   | int8    |
-    | signed 16 to signed 32   | int16   | int32   |
-    | signed 32 to signed 8    | int32   | int8    |
-    | signed 32 to signed 16   | int32   | int16   |
-    | float to signed 8        | float   | int8    |
-    | float to signed 16       | float   | int16   |
-    | signed 8 to float        | int8    | float   |
-    | signed 16 to float       | int16   | float   |
-    | float 32 to float 64     | float32 | float64 |
-    | float 64 to float 32     | float64 | float32 |
-  }];
-
-  let arguments = (ins
-    Tosa_Tensor:$input
-  );
-
-  let results = (outs
-    Tosa_Tensor:$output
+    | bool to int 8            | Boolean | int8    |
+    | bool to int 16           | Boolean | int16   |
+    | bool to int 32           | Boolean | int32   |
+    | int 8 to bool            | int8    | Boolean |
+    | int 8 to int 16          | int8    | int16   |
+    | int 8 to int 32          | int8    | int32   |
+    | int 8 to fp16            | int8    | float16 |
+    | int 8 to bf16            | int8    | bf16    |
+    | int 8 to fp32            | int8    | float32 |
+    | int 16 to bool           | int16   | Boolean |
+    | int 16 to int 8          | int16   | int8    |
+    | int 16 to int 32         | int16   | int32   |
+    | int 16 to fp16           | int16   | float16 |
+    | int 16 to bf16           | int16   | bf16    |
+    | int 16 to fp32           | int16   | float32 |
+    | int 32 to bool           | int32   | Boolean |
+    | int 32 to int 8          | int32   | int8    |
+    | int 32 to int 16         | int32   | int16   |
+    | int 32 to fp16           | int32   | float16 |
+    | int 32 to bf16           | int32   | bf16    |
+    | int 32 to fp32           | int32   | float32 |
+    | bf16 to int 8            | bf16    | int8    |
+    | bf16 to int 16           | bf16    | int16   |
+    | bf16 to int 32           | bf16    | int32   |
+    | bf16 to fp8e4m3          | bf16    | fp8e4m3 |
+    | bf16 to fp8e5m2          | bf16    | fp8e5m2 |
+    | bf16 to fp32             | bf16    | float32 |
+    | fp8e4m3 to fp16          | fp8e4m3 | float16 |
+    | fp8e4m3 to bf16          | fp8e4m3 | bf16    |
+    | fp8e4m3 to fp32          | fp8e4m3 | float32 |
+    | fp8e5m2 to fp16          | fp8e5m2 | float16 |
+    | fp8e5m2 to bf16          | fp8e5m2 | bf16    |
+    | fp8e5m2 to fp32          | fp8e5m2 | float32 |
+    | fp16 to int 8            | float16 | int8    |
+    | fp16 to int 16           | float16 | int16   |
+    | fp16 to int 32           | float16 | int32   |
+    | fp16 to fp8e4m3          | float16 | fp8e4m3 |
+    | fp16 to fp8e5m2          | float16 | fp8e5m2 |
+    | fp16 to fp32             | float16 | float32 |
+    | fp32 to int 8            | float32 | int8    |
+    | fp32 to int 16           | float32 | int16   |
+    | fp32 to int 32           | float32 | int32   |
+    | fp32 to fp8e4m3          | float32 | fp8e4m3 |
+    | fp32 to fp8e5m2          | float32 | fp8e5m2 |
+    | fp32 to bf16             | float32 | bf16    |
+    | fp32 to fp16             | float32 | float16 |
+  }];
+
+  let arguments = (ins
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64]>]>:$input
+  );
+
+  let results = (outs
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64]>]>:$output
   );
 
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1940,7 +1979,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
   );
 
   let results = (outs
-    TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64, Tosa_Int4]>]>:$output
   );
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index cf6ddc66f4ada..2c6e647ae32fd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -74,16 +74,25 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
                                    Tosa_QuantizedType<"int16", [16, 0], 1>,
                                    Tosa_QuantizedType<"int32", [32, 0], 1>]>;
 
+def Tosa_F8 : AnyTypeOf<[
+                        F8E4M3FN,
+                        F8E5M2]>;
+
 //===----------------------------------------------------------------------===//
 // Multi-category types.
 //===----------------------------------------------------------------------===//
 def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                 "number">;
 
+// Add F8 type support to Tosa_AnyNumber
+def Tosa_AnyNumber_Extended : AnyTypeOf<[Tosa_AnyNumber, Tosa_F8],
+                               "number_extended">;
+
 // For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
 // tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
 def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
-                             Tosa_QuantizedInt, AnyFloat]>;
+                             Tosa_QuantizedInt, AnyFloat, Tosa_F8]>;
+
 
 //===----------------------------------------------------------------------===//
 // TOSA Tensor Conformance
@@ -130,9 +139,11 @@ def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
 
 // Either ranked or unranked tensor of TOSA supported element types.
 def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor_Extended : TosaTensorOf<[Tosa_AnyNumber_Extended]>;
 
 // Must be ranked but no further constraints
-def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor_Extended : RankedTensorOf<[Tosa_AnyNumber_Extended]>;
 
 // Any tensor element type allowed in Tosa ops.
 def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
@@ -145,9 +156,9 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
 // Tensor types with constrained ranks.
 //===----------------------------------------------------------------------===//
 
-def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
-
+// Scalar tensors: Rank-1 (with only one element)
 def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
+def Tosa_ScalarTensor_Extended : TosaScalarTensorOf<[Tosa_AnyNumber_Extended], [1]>;
 def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
 
 // We include unranked tensors as a supported type for all possible tosa
@@ -155,6 +166,7 @@ def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
 // they should be shape propagate used Tosa's shape inference pass and verified
 // to not include any remaining unranked tensors.
 def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_UnrankedTensorExtended : TosaUnrankedTensorOf<[Tosa_AnyNumber_Extended]>;
 
 def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
 def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
@@ -162,6 +174,17 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
 def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
 def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
 
+def Tosa_Tensor1D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [1]>],
+    "1-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor2D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [2]>],
+    "2-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor3D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [3]>],
+    "3-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor4D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [4]>],
+    "4-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor5D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [5]>],
+    "5-d tosa-conformant tensor extended", "::mlir::TensorType">;
+
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 69b3f6d674167..704f8a82d11fa 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -65,12 +65,12 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
-  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
+  auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
   if (!notOp)
     return failure();
   rewriter.modifyOpInPlace(op, [&]() {
     op.getOperation()->setOperands(
-        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
+        {notOp.getInput1(), op.getInput3(), op.getInput2()});
   });
   return success();
 }
@@ -1118,18 +1118,18 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
-  if (getOnTrue() == getOnFalse())
-    return getOnTrue();
+  if (getInput2() == getInput3())
+    return getInput2();
 
   auto predicate =
-      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
+      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
   if (!predicate)
     return {};
 
   if (!predicate.isSplat())
     return {};
-  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
-                                                         : getOnFalse();
+  return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
+                                                         : getInput3();
 }
 
 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 67021d6c07401..411f06f4a0b7c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -217,15 +217,17 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
 
 template <typename T>
 static LogicalResult verifyConvOp(T op) {
-  // All TOSA conv ops have an input and weight arguments which must be ranked
-  // tensors.
+  // All TOSA conv ops have an input() and weight().
   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+
+  Ra...
[truncated]
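
The updated `tosa.cast` mode table in the patch lists FP8 conversions only to and from `float16`, `bf16`, and `float32`. The Python sketch below encodes just those FP8 rows as a lookup, using the type spellings from the table. The helper name is hypothetical, and whether the new `CastOp` verifier enforces exactly this table cannot be seen in the truncated diff, so treat it as a reading aid rather than the verifier's logic.

```python
# Type spellings follow the Input/Output columns of the cast table in the patch.
FP8 = ("fp8e4m3", "fp8e5m2")
WIDER_FLOATS = ("float16", "bf16", "float32")

# Every FP8 row of the table: FP8 -> wider float, and wider float -> FP8.
FP8_CAST_MODES = (
    {(f8, wide) for f8 in FP8 for wide in WIDER_FLOATS}
    | {(wide, f8) for f8 in FP8 for wide in WIDER_FLOATS}
)

def is_listed_fp8_cast(src: str, dst: str) -> bool:
    """True when (src, dst) is one of the FP8 modes listed in the table."""
    return (src, dst) in FP8_CAST_MODES
```

Read this way, the table has no direct integer-to-FP8 mode and no conversion between the two FP8 formats; such casts would go through one of the wider float types.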

llvmbot (Member) commented Feb 19, 2025

@llvm/pr-subscribers-mlir-linalg

-    | float to signed 16       | float   | int16   |
-    | signed 8 to float        | int8    | float   |
-    | signed 16 to float       | int16   | float   |
-    | float 32 to float 64     | float32 | float64 |
-    | float 64 to float 32     | float64 | float32 |
-  }];
-
-  let arguments = (ins
-    Tosa_Tensor:$input
-  );
-
-  let results = (outs
-    Tosa_Tensor:$output
+    | bool to int 8            | Boolean | int8    |
+    | bool to int 16           | Boolean | int16   |
+    | bool to int 32           | Boolean | int32   |
+    | int 8 to bool            | int8    | Boolean |
+    | int 8 to int 16          | int8    | int16   |
+    | int 8 to int 32          | int8    | int32   |
+    | int 8 to fp16            | int8    | float16 |
+    | int 8 to bf16            | int8    | bf16    |
+    | int 8 to fp32            | int8    | float32 |
+    | int 16 to bool           | int16   | Boolean |
+    | int 16 to int 8          | int16   | int8    |
+    | int 16 to int 32         | int16   | int32   |
+    | int 16 to fp16           | int16   | float16 |
+    | int 16 to bf16           | int16   | bf16    |
+    | int 16 to fp32           | int16   | float32 |
+    | int 32 to bool           | int32   | Boolean |
+    | int 32 to int 8          | int32   | int8    |
+    | int 32 to int 16         | int32   | int16   |
+    | int 32 to fp16           | int32   | float16 |
+    | int 32 to bf16           | int32   | bf16    |
+    | int 32 to fp32           | int32   | float32 |
+    | bf16 to int 8            | bf16    | int8    |
+    | bf16 to int 16           | bf16    | int16   |
+    | bf16 to int 32           | bf16    | int32   |
+    | bf16 to fp8e4m3          | bf16    | fp8e4m3 |
+    | bf16 to fp8e5m2          | bf16    | fp8e5m2 |
+    | bf16 to fp32             | bf16    | float32 |
+    | fp8e4m3 to fp16          | fp8e4m3 | float16 |
+    | fp8e4m3 to bf16          | fp8e4m3 | bf16    |
+    | fp8e4m3 to fp32          | fp8e4m3 | float32 |
+    | fp8e5m2 to fp16          | fp8e5m2 | float16 |
+    | fp8e5m2 to bf16          | fp8e5m2 | bf16    |
+    | fp8e5m2 to fp32          | fp8e5m2 | float32 |
+    | fp16 to int 8            | float16 | int8    |
+    | fp16 to int 16           | float16 | int16   |
+    | fp16 to int 32           | float16 | int32   |
+    | fp16 to fp8e4m3          | float16 | fp8e4m3 |
+    | fp16 to fp8e5m2          | float16 | fp8e5m2 |
+    | fp16 to fp32             | float16 | float32 |
+    | fp32 to int 8            | float32 | int8    |
+    | fp32 to int 16           | float32 | int16   |
+    | fp32 to int 32           | float32 | int32   |
+    | fp32 to fp8e4m3          | float32 | fp8e4m3 |
+    | fp32 to fp8e5m2          | float32 | fp8e5m2 |
+    | fp32 to bf16             | float32 | bf16    |
+    | fp32 to fp16             | float32 | float16 |
+  }];
+
+  let arguments = (ins
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64]>]>:$input
+  );
+
+  let results = (outs
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64]>]>:$output
   );
 
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1940,7 +1979,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
   );
 
   let results = (outs
-    TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64, Tosa_Int4]>]>:$output
   );
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index cf6ddc66f4ada..2c6e647ae32fd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -74,16 +74,25 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
                                    Tosa_QuantizedType<"int16", [16, 0], 1>,
                                    Tosa_QuantizedType<"int32", [32, 0], 1>]>;
 
+def Tosa_F8 : AnyTypeOf<[
+                        F8E4M3FN,
+                        F8E5M2]>;
+
 //===----------------------------------------------------------------------===//
 // Multi-category types.
 //===----------------------------------------------------------------------===//
 def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                 "number">;
 
+// Add F8 type support to Tosa_AnyNumber
+def Tosa_AnyNumber_Extended : AnyTypeOf<[Tosa_AnyNumber, Tosa_F8],
+                               "number_extended">;
+
 // For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
 // tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
 def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
-                             Tosa_QuantizedInt, AnyFloat]>;
+                             Tosa_QuantizedInt, AnyFloat, Tosa_F8]>;
+
 
 //===----------------------------------------------------------------------===//
 // TOSA Tensor Conformance
@@ -130,9 +139,11 @@ def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
 
 // Either ranked or unranked tensor of TOSA supported element types.
 def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor_Extended : TosaTensorOf<[Tosa_AnyNumber_Extended]>;
 
 // Must be ranked but no further constraints
-def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor_Extended : RankedTensorOf<[Tosa_AnyNumber_Extended]>;
 
 // Any tensor element type allowed in Tosa ops.
 def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
@@ -145,9 +156,9 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
 // Tensor types with constrained ranks.
 //===----------------------------------------------------------------------===//
 
-def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
-
+// Scalar tensors: Rank-1 (with only one element)
 def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
+def Tosa_ScalarTensor_Extended : TosaScalarTensorOf<[Tosa_AnyNumber_Extended], [1]>;
 def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
 
 // We include unranked tensors as a supported type for all possible tosa
@@ -155,6 +166,7 @@ def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
 // they should be shape propagate used Tosa's shape inference pass and verified
 // to not include any remaining unranked tensors.
 def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_UnrankedTensorExtended : TosaUnrankedTensorOf<[Tosa_AnyNumber_Extended]>;
 
 def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
 def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
@@ -162,6 +174,17 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
 def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
 def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
 
+def Tosa_Tensor1D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [1]>],
+    "1-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor2D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [2]>],
+    "2-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor3D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [3]>],
+    "3-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor4D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [4]>],
+    "4-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor5D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [5]>],
+    "5-d tosa-conformant tensor extended", "::mlir::TensorType">;
+
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 69b3f6d674167..704f8a82d11fa 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -65,12 +65,12 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
-  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
+  auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
   if (!notOp)
     return failure();
   rewriter.modifyOpInPlace(op, [&]() {
     op.getOperation()->setOperands(
-        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
+        {notOp.getInput1(), op.getInput3(), op.getInput2()});
   });
   return success();
 }
@@ -1118,18 +1118,18 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
-  if (getOnTrue() == getOnFalse())
-    return getOnTrue();
+  if (getInput2() == getInput3())
+    return getInput2();
 
   auto predicate =
-      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
+      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
   if (!predicate)
     return {};
 
   if (!predicate.isSplat())
     return {};
-  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
-                                                         : getOnFalse();
+  return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
+                                                         : getInput3();
 }
 
 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 67021d6c07401..411f06f4a0b7c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -217,15 +217,17 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
 
 template <typename T>
 static LogicalResult verifyConvOp(T op) {
-  // All TOSA conv ops have an input and weight arguments which must be ranked
-  // tensors.
+  // All TOSA conv ops have an input() and weight().
   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+
+  Ra...
[truncated]
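
The `SelectOp` hunks in `TosaCanonicalizations.cpp` above rename the accessors (`getPred`/`getOnTrue`/`getOnFalse` become `getInput1`/`getInput2`/`getInput3`) without changing behavior. As a rough illustration of what the canonicalization and fold do, here is a minimal standalone sketch in plain C++. It is not MLIR code: `tosaSelect`, `selectOfNotCanonical`, and `foldSelect` are hypothetical names, and integers stand in for tensor operands (the real fold compares SSA values, not their contents).

```cpp
#include <cassert>
#include <optional>

// Plain-value model of tosa.select: input1 is the predicate, input2 is the
// value taken when it is true, input3 the value taken when it is false.
int tosaSelect(bool input1, int input2, int input3) {
  return input1 ? input2 : input3;
}

// Model of the canonicalized form of select(logical_not(p), a, b):
// drop the negation and swap the two value operands.
int selectOfNotCanonical(bool p, int a, int b) { return tosaSelect(p, b, a); }

// Model of the fold: identical value operands fold to that operand; a known
// splat predicate picks one operand; otherwise nothing folds.
std::optional<int> foldSelect(std::optional<bool> splatPred, int input2,
                              int input3) {
  if (input2 == input3)
    return input2;
  if (!splatPred)
    return std::nullopt;
  return *splatPred ? input2 : input3;
}
```

Checking both predicate values shows that `tosaSelect(!p, a, b)` and `selectOfNotCanonical(p, a, b)` agree, which is why the rewrite can swap `input2` and `input3` safely.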

@Jerry-Ge Jerry-Ge requested a review from sjarus February 19, 2025 00:38

github-actions bot commented Feb 19, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Comment on lines 475 to 478
  if ((llvm::isa<Float8E5M2Type>(inputETy) ||
       llvm::isa<Float8E4M3FNType>(inputETy)) &&
      !accType.isF16())
    return emitOpError("accumulator type for f8 tensor is not f16");
Contributor

This seems too restrictive. The max value for the Float8E5M2 type is 57344, while the max value for an fp16 accumulator is only about 65k (65504). FP8 requires an fp32 accumulator, not fp16.
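
The range argument in this comment can be checked with a few lines of arithmetic. The sketch below is plain C++, independent of MLIR, and every function name in it is hypothetical. It computes the largest finite value of each format from its mantissa width and largest exponent, assuming IEEE-style encodings where the top exponent code is reserved for inf/NaN (true for F8E5M2, f16, and f32; F8E4M3FN uses a different encoding and is not covered here).

```cpp
#include <cassert>
#include <cmath>

// Largest finite value of an IEEE-style binary format that reserves the top
// exponent code for inf/NaN: (2 - 2^-mantissaBits) * 2^maxExponent.
double maxFinite(int mantissaBits, int maxExponent) {
  assert(mantissaBits > 0 && "format needs at least one mantissa bit");
  return (2.0 - std::ldexp(1.0, -mantissaBits)) * std::ldexp(1.0, maxExponent);
}

// F8E5M2: 2 mantissa bits, largest unbiased exponent 15.
double f8E5M2Max() { return maxFinite(2, 15); }

// IEEE half precision (f16): 10 mantissa bits, largest unbiased exponent 15.
double f16Max() { return maxFinite(10, 15); }

// IEEE single precision (f32): 23 mantissa bits, largest unbiased exponent 127.
double f32Max() { return maxFinite(23, 127); }

// How many maximal-magnitude terms a running sum can hold before exceeding
// accMax. Rounding is ignored, so this is only an upper bound.
double maxTermsBeforeOverflow(double termMax, double accMax) {
  return std::floor(accMax / termMax);
}
```

Under these assumptions, `maxTermsBeforeOverflow(f8E5M2Max(), f16Max())` is 1: a single maximal F8E5M2 value fits in f16, but the sum of two already exceeds 65504. Convolutions and matmuls accumulate products rather than raw inputs, so the headroom in practice is even smaller.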

Contributor

(added a related comment above before I saw this one: https://github.com/llvm/llvm-project/pull/127730/files#r1962023091)

Contributor

My concern is about the validity of having an fp16 accumulator, not about where we do that validation.

Member Author

Thanks @umangyadav. This is a reasonable request. We'll add another FP32 accumulator type likely in 1.1.

Contributor

@lhutton1 lhutton1 left a comment

Awesome we were able to simplify this patch so much :)

@Jerry-Ge Jerry-Ge changed the title [mlir][tosa] Add FP8 support [mlir][tosa] Add FP8 lit tests Mar 6, 2025
Contributor

@lhutton1 lhutton1 left a comment

LGTM, thanks! Note - it needs a rebase over #129943

Add FP8 lit tests to the following operators:

ARGMAX
AVGPOOL
CONV2D
CONV3D
DEPTHWISE_CONV2D
MATMUL
MAX_POOL2D
TRANSPOSE_CONV2D
CONST
CAST
CONCAT
PAD
RESHAPE
REVERSE
SLICE
TILE
TRANSPOSE
GATHER
SCATTER

Signed-off-by: Tai Ly <[email protected]>
Signed-off-by: Jerry Ge <[email protected]>
Change-Id: I56adfabb2396e38b7ed3479e4fd680b740bdb4e4
@Jerry-Ge Jerry-Ge merged commit ca582b1 into llvm:main Mar 7, 2025
11 checks passed
jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025