Skip to content

[mlir][tosa] Switch zero point of avgpool2d to input variable type #128983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ profileComplianceMap = {
{{{Profile::pro_int}, {{i8T, i32T}}},
{{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
{"tosa.avg_pool2d",
{{{Profile::pro_int}, {{i8T, i32T, i8T}}},
{{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
{{Profile::pro_fp},
{{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T}}}}},
{{fp16T, fp16T, fp16T, fp16T, fp16T},
{fp16T, fp16T, fp16T, fp32T, fp16T},
{fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.conv2d",
{{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
{{Profile::pro_fp},
Expand Down Expand Up @@ -243,10 +245,10 @@ extensionComplianceMap = {
{{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
{{Extension::bf16}, {{bf16T, i32T}}}}},
{"tosa.avg_pool2d",
{{{Extension::int16}, {{i16T, i32T, i16T}}},
{{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}},
{{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}},
{{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}},
{{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
{{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
{"tosa.conv2d",
{{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
{{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
Expand Down
46 changes: 27 additions & 19 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$output_zp,
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<I32Attr>:$input_zp,
OptionalAttr<I32Attr>:$output_zp
TypeAttrOf<Tosa_AccType>:$acc_type
);

let results = (outs
Expand All @@ -97,6 +97,14 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
];

let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];

let extraClassDeclaration = [{
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getOutputZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyOutputZeroPoint(int64_t zp);
}];

let hasVerifier = 1;
}

Expand All @@ -116,8 +124,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,

Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Expand All @@ -136,8 +144,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getWeightZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];
Expand All @@ -161,8 +169,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
Tosa_Tensor5D:$input,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,

Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Expand All @@ -181,8 +189,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getWeightZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];
Expand All @@ -207,8 +215,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,

Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Expand All @@ -227,8 +235,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getWeightZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];
Expand Down Expand Up @@ -412,8 +420,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,

Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Expand All @@ -431,8 +439,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getWeightZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;

def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
def Tosa_ScalarIntOrFloatTensor : TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>;

// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
Expand Down
72 changes: 50 additions & 22 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,18 +260,26 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();

// Get and verify zero points.
int64_t inputZpVal;
int64_t weightZpVal;
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp))
return rewriter.notifyMatchFailure(
op, "input zero point cannot be statically determined");

FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
if (failed(maybeWZp))
return rewriter.notifyMatchFailure(
op, "weight zero point cannot be statically determined");

if (op.getInputZeroPoint(inputZpVal).failed() ||
op.getWeightZeroPoint(weightZpVal).failed())
int64_t inputZpVal = *maybeIZp;
int64_t weightZpVal = *maybeWZp;

if (op.verifyInputZeroPoint(inputZpVal).failed())
return rewriter.notifyMatchFailure(
op, "bail out if zero points cannot statically be determined");
op, "input zero point must be zero for non-int8 integer types");

if (op.verifyInputZeroPoint(inputZpVal).failed() ||
op.verifyWeightZeroPoint(weightZpVal).failed())
if (op.verifyWeightZeroPoint(weightZpVal).failed())
return rewriter.notifyMatchFailure(
op, "zero point must be zero for non-int8 integer types");
op, "weight zero point must be zero for non-int8 integer types");

bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);

Expand Down Expand Up @@ -448,18 +456,26 @@ class DepthwiseConvConverter
/*kernelSizeDims=*/{0, 1}, rewriter);

// Get and verify zero points.
int64_t inputZpVal;
int64_t weightZpVal;

if (op.getInputZeroPoint(inputZpVal).failed() ||
op.getWeightZeroPoint(weightZpVal).failed())
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
if (failed(maybeIZp))
return rewriter.notifyMatchFailure(
op, "input zero point cannot be statically determined");
if (failed(maybeWZp))
return rewriter.notifyMatchFailure(
op, "weight zero point cannot be statically determined");

int64_t inputZpVal = *maybeIZp;
int64_t weightZpVal = *maybeWZp;

if (op.verifyInputZeroPoint(inputZpVal).failed())
return rewriter.notifyMatchFailure(
op, "bail out if zero points cannot statically be determined");
op, "input zero point must be zero for non-int8 integer types");

if (op.verifyInputZeroPoint(inputZpVal).failed() ||
op.verifyWeightZeroPoint(weightZpVal).failed())
if (op.verifyWeightZeroPoint(weightZpVal).failed())
return rewriter.notifyMatchFailure(
op, "zero point must be zero for non-int8 integer types");
op, "weight zero point must be zero for non-int8 integer types");

bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
auto weightShape = weightTy.getShape();
Expand Down Expand Up @@ -809,6 +825,18 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
return failure();
SmallVector<Value> dynamicDims = *dynamicDimsOr;

FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
if (failed(maybeIZp))
return rewriter.notifyMatchFailure(
op, "input zero point could not be statically determined");
if (failed(maybeOZp))
return rewriter.notifyMatchFailure(
op, "output zero point could not be statically determined");

int64_t inputZpVal = *maybeIZp;
int64_t outputZpVal = *maybeOZp;

// Apply padding as necessary.
llvm::SmallVector<int64_t> pad;
pad.resize(2, 0);
Expand Down Expand Up @@ -928,9 +956,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {

// If we have quantization information we need to apply an offset
// for the input zp value.
if (op.getInputZp()) {
auto inputZp =
rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
if (inputZpVal != 0) {
auto inputZp = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(accETy, inputZpVal));
Value offset =
rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
poolVal =
Expand Down Expand Up @@ -982,9 +1010,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {

// If we have quantization information we need to apply output
// zeropoint.
if (op.getOutputZp()) {
auto outputZp =
rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
if (outputZpVal != 0) {
auto outputZp = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
.getResult();
}
Expand Down
Loading