Skip to content

Commit 25a29ce

Browse files
Tai78641lhutton1
andauthored
[mlir][tosa] Switch zero point of avgpool2d to input variable type (#128983)
This commit changes the TOSA operator AvgPool2d's zero point attributes to inputs to align with TOSA 1.0 spec. Signed-off-by: Luke Hutton <[email protected]> Co-authored-by: Luke Hutton <[email protected]>
1 parent 17bfc00 commit 25a29ce

18 files changed

+355
-204
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ profileComplianceMap = {
55
{{{Profile::pro_int}, {{i8T, i32T}}},
66
{{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
77
{"tosa.avg_pool2d",
8-
{{{Profile::pro_int}, {{i8T, i32T, i8T}}},
8+
{{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
99
{{Profile::pro_fp},
10-
{{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T}}}}},
10+
{{fp16T, fp16T, fp16T, fp16T, fp16T},
11+
{fp16T, fp16T, fp16T, fp32T, fp16T},
12+
{fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
1113
{"tosa.conv2d",
1214
{{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
1315
{{Profile::pro_fp},
@@ -243,10 +245,10 @@ extensionComplianceMap = {
243245
{{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
244246
{{Extension::bf16}, {{bf16T, i32T}}}}},
245247
{"tosa.avg_pool2d",
246-
{{{Extension::int16}, {{i16T, i32T, i16T}}},
247-
{{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}},
248-
{{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}},
249-
{{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}},
248+
{{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
249+
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
250+
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
251+
{{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
250252
{"tosa.conv2d",
251253
{{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
252254
{{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
7979

8080
let arguments = (ins
8181
Tosa_Tensor4D:$input,
82+
Tosa_ScalarIntOrFloatTensor:$input_zp,
83+
Tosa_ScalarIntOrFloatTensor:$output_zp,
8284
Tosa_IntArrayAttr2:$kernel,
8385
Tosa_IntArrayAttr2:$stride,
8486
Tosa_IntArrayAttr4:$pad,
85-
TypeAttrOf<Tosa_AccType>:$acc_type,
86-
OptionalAttr<I32Attr>:$input_zp,
87-
OptionalAttr<I32Attr>:$output_zp
87+
TypeAttrOf<Tosa_AccType>:$acc_type
8888
);
8989

9090
let results = (outs
@@ -97,6 +97,14 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
9797
];
9898

9999
let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
100+
101+
let extraClassDeclaration = [{
102+
FailureOr<int64_t> getInputZeroPoint();
103+
FailureOr<int64_t> getOutputZeroPoint();
104+
LogicalResult verifyInputZeroPoint(int64_t zp);
105+
LogicalResult verifyOutputZeroPoint(int64_t zp);
106+
}];
107+
100108
let hasVerifier = 1;
101109
}
102110

@@ -116,8 +124,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
116124
Tosa_Tensor4D:$input,
117125
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
118126
Tosa_Tensor1D:$bias,
119-
Tosa_ScalarTensor:$input_zp,
120-
Tosa_ScalarTensor:$weight_zp,
127+
Tosa_ScalarIntOrFloatTensor:$input_zp,
128+
Tosa_ScalarIntOrFloatTensor:$weight_zp,
121129

122130
Tosa_IntArrayAttr4:$pad,
123131
Tosa_IntArrayAttr2:$stride,
@@ -136,8 +144,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
136144
];
137145

138146
let extraClassDeclaration = [{
139-
LogicalResult getInputZeroPoint(int64_t &zp);
140-
LogicalResult getWeightZeroPoint(int64_t &zp);
147+
FailureOr<int64_t> getInputZeroPoint();
148+
FailureOr<int64_t> getWeightZeroPoint();
141149
LogicalResult verifyInputZeroPoint(int64_t zp);
142150
LogicalResult verifyWeightZeroPoint(int64_t zp);
143151
}];
@@ -161,8 +169,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
161169
Tosa_Tensor5D:$input,
162170
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
163171
Tosa_Tensor1D:$bias,
164-
Tosa_ScalarTensor:$input_zp,
165-
Tosa_ScalarTensor:$weight_zp,
172+
Tosa_ScalarIntOrFloatTensor:$input_zp,
173+
Tosa_ScalarIntOrFloatTensor:$weight_zp,
166174

167175
Tosa_IntArrayAttr6:$pad,
168176
Tosa_IntArrayAttr3:$stride,
@@ -181,8 +189,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
181189
];
182190

183191
let extraClassDeclaration = [{
184-
LogicalResult getInputZeroPoint(int64_t &zp);
185-
LogicalResult getWeightZeroPoint(int64_t &zp);
192+
FailureOr<int64_t> getInputZeroPoint();
193+
FailureOr<int64_t> getWeightZeroPoint();
186194
LogicalResult verifyInputZeroPoint(int64_t zp);
187195
LogicalResult verifyWeightZeroPoint(int64_t zp);
188196
}];
@@ -207,8 +215,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
207215
Tosa_Tensor4D:$input,
208216
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
209217
Tosa_Tensor1D:$bias,
210-
Tosa_ScalarTensor:$input_zp,
211-
Tosa_ScalarTensor:$weight_zp,
218+
Tosa_ScalarIntOrFloatTensor:$input_zp,
219+
Tosa_ScalarIntOrFloatTensor:$weight_zp,
212220

213221
Tosa_IntArrayAttr4:$pad,
214222
Tosa_IntArrayAttr2:$stride,
@@ -227,8 +235,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
227235
];
228236

229237
let extraClassDeclaration = [{
230-
LogicalResult getInputZeroPoint(int64_t &zp);
231-
LogicalResult getWeightZeroPoint(int64_t &zp);
238+
FailureOr<int64_t> getInputZeroPoint();
239+
FailureOr<int64_t> getWeightZeroPoint();
232240
LogicalResult verifyInputZeroPoint(int64_t zp);
233241
LogicalResult verifyWeightZeroPoint(int64_t zp);
234242
}];
@@ -412,8 +420,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
412420
Tosa_Tensor4D:$input,
413421
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
414422
Tosa_Tensor1D:$bias,
415-
Tosa_ScalarTensor:$input_zp,
416-
Tosa_ScalarTensor:$weight_zp,
423+
Tosa_ScalarIntOrFloatTensor:$input_zp,
424+
Tosa_ScalarIntOrFloatTensor:$weight_zp,
417425

418426
Tosa_IntArrayAttr4:$out_pad,
419427
Tosa_IntArrayAttr2:$stride,
@@ -431,8 +439,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
431439
];
432440

433441
let extraClassDeclaration = [{
434-
LogicalResult getInputZeroPoint(int64_t &zp);
435-
LogicalResult getWeightZeroPoint(int64_t &zp);
442+
FailureOr<int64_t> getInputZeroPoint();
443+
FailureOr<int64_t> getWeightZeroPoint();
436444
LogicalResult verifyInputZeroPoint(int64_t zp);
437445
LogicalResult verifyWeightZeroPoint(int64_t zp);
438446
}];

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
149149

150150
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
151151
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
152+
def Tosa_ScalarIntOrFloatTensor : TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>;
152153

153154
// We include unranked tensors as a supported type for all possible tosa
154155
// Tensors as unranked does not guarantee invalid. If unranked tensors exist

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -260,18 +260,26 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
260260
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
261261

262262
// Get and verify zero points.
263-
int64_t inputZpVal;
264-
int64_t weightZpVal;
263+
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
264+
if (failed(maybeIZp))
265+
return rewriter.notifyMatchFailure(
266+
op, "input zero point cannot be statically determined");
267+
268+
FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
269+
if (failed(maybeWZp))
270+
return rewriter.notifyMatchFailure(
271+
op, "weight zero point cannot be statically determined");
265272

266-
if (op.getInputZeroPoint(inputZpVal).failed() ||
267-
op.getWeightZeroPoint(weightZpVal).failed())
273+
int64_t inputZpVal = *maybeIZp;
274+
int64_t weightZpVal = *maybeWZp;
275+
276+
if (op.verifyInputZeroPoint(inputZpVal).failed())
268277
return rewriter.notifyMatchFailure(
269-
op, "bail out if zero points cannot statically be determined");
278+
op, "input zero point must be zero for non-int8 integer types");
270279

271-
if (op.verifyInputZeroPoint(inputZpVal).failed() ||
272-
op.verifyWeightZeroPoint(weightZpVal).failed())
280+
if (op.verifyWeightZeroPoint(weightZpVal).failed())
273281
return rewriter.notifyMatchFailure(
274-
op, "zero point must be zero for non-int8 integer types");
282+
op, "weight zero point must be zero for non-int8 integer types");
275283

276284
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
277285

@@ -448,18 +456,26 @@ class DepthwiseConvConverter
448456
/*kernelSizeDims=*/{0, 1}, rewriter);
449457

450458
// Get and verify zero points.
451-
int64_t inputZpVal;
452-
int64_t weightZpVal;
453459

454-
if (op.getInputZeroPoint(inputZpVal).failed() ||
455-
op.getWeightZeroPoint(weightZpVal).failed())
460+
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
461+
FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
462+
if (failed(maybeIZp))
463+
return rewriter.notifyMatchFailure(
464+
op, "input zero point cannot be statically determined");
465+
if (failed(maybeWZp))
466+
return rewriter.notifyMatchFailure(
467+
op, "weight zero point cannot be statically determined");
468+
469+
int64_t inputZpVal = *maybeIZp;
470+
int64_t weightZpVal = *maybeWZp;
471+
472+
if (op.verifyInputZeroPoint(inputZpVal).failed())
456473
return rewriter.notifyMatchFailure(
457-
op, "bail out if zero points cannot statically be determined");
474+
op, "input zero point must be zero for non-int8 integer types");
458475

459-
if (op.verifyInputZeroPoint(inputZpVal).failed() ||
460-
op.verifyWeightZeroPoint(weightZpVal).failed())
476+
if (op.verifyWeightZeroPoint(weightZpVal).failed())
461477
return rewriter.notifyMatchFailure(
462-
op, "zero point must be zero for non-int8 integer types");
478+
op, "weight zero point must be zero for non-int8 integer types");
463479

464480
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
465481
auto weightShape = weightTy.getShape();
@@ -809,6 +825,18 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
809825
return failure();
810826
SmallVector<Value> dynamicDims = *dynamicDimsOr;
811827

828+
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
829+
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
830+
if (failed(maybeIZp))
831+
return rewriter.notifyMatchFailure(
832+
op, "input zero point could not be statically determined");
833+
if (failed(maybeOZp))
834+
return rewriter.notifyMatchFailure(
835+
op, "output zero point could not be statically determined");
836+
837+
int64_t inputZpVal = *maybeIZp;
838+
int64_t outputZpVal = *maybeOZp;
839+
812840
// Apply padding as necessary.
813841
llvm::SmallVector<int64_t> pad;
814842
pad.resize(2, 0);
@@ -928,9 +956,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
928956

929957
// If we have quantization information we need to apply an offset
930958
// for the input zp value.
931-
if (op.getInputZp()) {
932-
auto inputZp =
933-
rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
959+
if (inputZpVal != 0) {
960+
auto inputZp = rewriter.create<arith::ConstantOp>(
961+
loc, b.getIntegerAttr(accETy, inputZpVal));
934962
Value offset =
935963
rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
936964
poolVal =
@@ -982,9 +1010,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
9821010

9831011
// If we have quantization information we need to apply output
9841012
// zeropoint.
985-
if (op.getOutputZp()) {
986-
auto outputZp =
987-
rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
1013+
if (outputZpVal != 0) {
1014+
auto outputZp = rewriter.create<arith::ConstantOp>(
1015+
loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
9881016
scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
9891017
.getResult();
9901018
}

0 commit comments

Comments
 (0)