Commit 55841ba

tatwaichong authored and lhutton1 committed
[mlir][tosa] Enhance CONV3D & DEPTHWISE_CONV2D verifier (llvm#135738)
Verify the correctness of pad, stride, dilation, and the dimensions of input/weight/bias/output. Adapt and extend the existing conv2d ERROR_IF checks to support additional convolution variants.

(cherry-picked from commit e2a9902)
Change-Id: Ic3bf041c3da8f8abbe8798b49948a811d63e187e
1 parent e1c0db0 · commit 55841ba
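For reference, the ERROR_IF relation enforced per spatial dimension (quoted from the TOSA spec in the new code below) is

    O = idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1

Worked against the conv3d test shapes updated below (pad = 0, stride = 1, dilation = 1): depth (49 - 1 - (3 - 1)) / 1 + 1 = 47, height (48 - 1 - (4 - 1)) / 1 + 1 = 45, width (47 - 1 - (5 - 1)) / 1 + 1 = 43.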

File tree

11 files changed: +425 −218 lines

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 150 additions & 94 deletions
@@ -428,6 +428,150 @@ static LogicalResult verifyConvOpModes(T op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ERROR_IF functions.
+// ERROR_IF is a predicate that must set an error if the condition holds.
+//===----------------------------------------------------------------------===//
+
+template <typename T>
+static LogicalResult verifyConvOpErrorIf(T op) {
+  llvm::ArrayRef<int64_t> padding = op.getPad();
+  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
+    return op.emitOpError("expect all padding values to be >= 0, got ")
+           << padding;
+
+  llvm::ArrayRef<int64_t> strides = op.getStride();
+  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
+    return op.emitOpError("expect all stride values to be >= 1, got ")
+           << strides;
+
+  llvm::ArrayRef<int64_t> dilations = op.getDilation();
+  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
+    return op.emitOpError("expect all dilation values to be >= 1, got ")
+           << dilations;
+
+  const RankedTensorType outputType =
+      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
+  if (!outputType)
+    // Skip following checks if output is not ranked
+    return success();
+
+  const RankedTensorType inputType =
+      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+  const RankedTensorType weightType =
+      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
+
+  if (inputType && weightType) {
+    const auto verifyOutputSize =
+        [&op](const int64_t inputSize, const int64_t kernelSize,
+              const int64_t outputSize, const int64_t padBefore,
+              const int64_t padAfter, const int64_t stride,
+              const int64_t dilation, const llvm::StringRef dimName,
+              const llvm::StringRef dimAxis,
+              const llvm::StringRef padBeforeName,
+              const llvm::StringRef padAfterName) -> LogicalResult {
+      if (inputSize == ShapedType::kDynamic ||
+          kernelSize == ShapedType::kDynamic)
+        return success();
+
+      // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
+
+      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
+          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
+          stride);
+      if (!calculatedOutSizeMinusOne.has_value())
+        return op.emitOpError("expected input_")
+               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+               << padAfterName << " - (kernel_" << dimName
+               << " - 1) * dilation_" << dimAxis
+               << " to be wholly divisible by stride_" << dimAxis << ", got ("
+               << inputSize << " - 1 + " << padBefore << " + " << padAfter
+               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
+               << stride;
+
+      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
+      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
+        return op.emitOpError("calculated output ")
+               << dimName << " did not match expected: "
+               << "calculated=" << calculatedOutSize
+               << ", expected=" << outputSize;
+
+      return success();
+    };
+
+    // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
+    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(1), weightType.getDimSize(1),
+              outputType.getDimSize(1), padding[0], padding[1], strides[0],
+              dilations[0], "height", "y", "top", "bottom")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(2), weightType.getDimSize(2),
+              outputType.getDimSize(2), padding[2], padding[3], strides[1],
+              dilations[1], "width", "x", "left", "right")))
+        return failure();
+    }
+
+    // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
+    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(1), weightType.getDimSize(0),
+              outputType.getDimSize(1), padding[0], padding[1], strides[0],
+              dilations[0], "height", "y", "top", "bottom")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(2), weightType.getDimSize(1),
+              outputType.getDimSize(2), padding[2], padding[3], strides[1],
+              dilations[1], "width", "x", "left", "right")))
+        return failure();
+    }
+
+    // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
+    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(1), weightType.getDimSize(1),
+              outputType.getDimSize(1), padding[0], padding[1], strides[0],
+              dilations[0], "depth", "d", "front", "back")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(2), weightType.getDimSize(2),
+              outputType.getDimSize(2), padding[2], padding[3], strides[1],
+              dilations[1], "height", "y", "top", "bottom")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(3), weightType.getDimSize(3),
+              outputType.getDimSize(3), padding[4], padding[5], strides[2],
+              dilations[2], "width", "x", "left", "right")))
+        return failure();
+    }
+  }
+
+  const RankedTensorType biasType =
+      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
+  if (!biasType)
+    // Skip following checks if bias is not ranked
+    return success();
+
+  const int64_t biasChannels = biasType.getDimSize(0);
+  const int64_t outputChannels = outputType.getDimSize(3);
+  if (biasChannels == ShapedType::kDynamic ||
+      outputChannels == ShapedType::kDynamic)
+    // Skip following checks if biasChannels or outputChannels is dynamic dim
+    return success();
+
+  if (biasChannels != outputChannels && biasChannels != 1)
+    return op.emitOpError(
+               "bias channels expected to be equal to output channels (")
+           << outputChannels << ") or 1, got " << biasChannels;
+
+  return success();
+}
+
 // verify that inType and outType have same element types
 template <typename T>
 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
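To make the DEPTHWISE_CONV2D dimension routing concrete — the weight is laid out [KH,KW,_,_] rather than conv2d's [_,KH,KW,_] — here is a minimal sketch with hypothetical shapes (not part of this commit) that satisfies the new checks: OH = (7 - 1 - (3 - 1)) / 1 + 1 = 5, OW = (5 - 1 - (1 - 1)) / 1 + 1 = 5, output channels = 3 * 11 = 33.

// Hypothetical shapes; zero points set to 0.0 for simplicity.
func.func @depthwise_example(%input: tensor<1x7x5x3xf32>, %weights: tensor<3x1x3x11xf32>, %bias: tensor<33xf32>) -> tensor<1x5x5x33xf32> {
  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  // The height check reads weight dim 0 (KH = 3); the width check reads weight dim 1 (KW = 1).
  %0 = tosa.depthwise_conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32>
  return %0 : tensor<1x5x5x33xf32>
}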
@@ -2767,99 +2911,9 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
 }
 
 LogicalResult Conv2DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
     return failure();
-
-  llvm::ArrayRef<int64_t> padding = getPad();
-  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
-    return emitOpError("expect all padding values to be >= 0, got ") << padding;
-
-  llvm::ArrayRef<int64_t> strides = getStride();
-  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
-    return emitOpError("expect all stride values to be >= 1, got ") << strides;
-
-  llvm::ArrayRef<int64_t> dilations = getDilation();
-  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
-    return emitOpError("expect all dilation values to be >= 1, got ")
-        << dilations;
-
-  const RankedTensorType outputType =
-      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
-  if (!outputType)
-    // Skip following checks if output is not ranked
-    return success();
-
-  const RankedTensorType inputType =
-      llvm::dyn_cast<RankedTensorType>(getInput().getType());
-  const RankedTensorType weightType =
-      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
-
-  if (inputType && weightType) {
-    const auto verifyOutputSize =
-        [this](const int64_t inputSize, const int64_t kernelSize,
-               const int64_t outputSize, const int64_t padBefore,
-               const int64_t padAfter, const int64_t stride,
-               const int64_t dilation, const llvm::StringRef dimName,
-               const llvm::StringRef dimAxis,
-               const llvm::StringRef padBeforeName,
-               const llvm::StringRef padAfterName) -> LogicalResult {
-      if (inputSize == ShapedType::kDynamic ||
-          kernelSize == ShapedType::kDynamic)
-        return success();
-
-      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
-          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
-          stride);
-      if (!calculatedOutSizeMinusOne.has_value())
-        return emitOpError("expected input_")
-               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
-               << padAfterName << " - (kernel_" << dimName
-               << " - 1) * dilation_" << dimAxis
-               << " to be wholly divisible by stride_" << dimAxis << ", got ("
-               << inputSize << " - 1 + " << padBefore << " + " << padAfter
-               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
-               << stride;
-
-      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
-      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
-        return emitOpError("calculated output ")
-               << dimName << " did not match expected: "
-               << "calculated=" << calculatedOutSize
-               << ", expected=" << outputSize;
-
-      return success();
-    };
-
-    if (failed(verifyOutputSize(
-            inputType.getDimSize(1), weightType.getDimSize(1),
-            outputType.getDimSize(1), padding[0], padding[1], strides[0],
-            dilations[0], "height", "y", "top", "bottom")))
-      return failure();
-
-    if (failed(verifyOutputSize(
-            inputType.getDimSize(2), weightType.getDimSize(2),
-            outputType.getDimSize(2), padding[2], padding[3], strides[1],
-            dilations[1], "width", "x", "left", "right")))
-      return failure();
-  }
-
-  const RankedTensorType biasType =
-      llvm::dyn_cast<RankedTensorType>(getBias().getType());
-  if (!biasType)
-    // Skip following checks if bias is not ranked
-    return success();
-
-  const int64_t biasChannels = biasType.getDimSize(0);
-  const int64_t outputChannels = outputType.getDimSize(3);
-  if (biasChannels == ShapedType::kDynamic ||
-      outputChannels == ShapedType::kDynamic)
-    // Skip following checks if biasChannels or outputChannels is dynamic dim
-    return success();
-
-  if (biasChannels != outputChannels && biasChannels != 1)
-    return emitOpError(
-               "bias channels expected to be equal to output channels (")
-           << outputChannels << ") or 1, got " << biasChannels;
   return success();
 }

@@ -2934,7 +2988,8 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
 }
 
 LogicalResult Conv3DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
     return failure();
   return success();
 }
@@ -3044,7 +3099,8 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
 }
 
 LogicalResult DepthwiseConv2DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
    return failure();
   return success();
 }
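As a quick illustration (hypothetical shapes, not taken from this commit's tests), the shared helper rejects IR whose declared output extent disagrees with the computed one. Here the calculated height is (49 - 1 + 0 + 0 - (3 - 1) * 1) / 1 + 1 = 47, so declaring 48 should produce "calculated output height did not match expected: calculated=47, expected=48":

func.func @conv2d_bad_output_height(%input: tensor<1x49x48x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  // Width is consistent ((48 - 1 - 2) / 1 + 1 = 46); only the declared height of 48 is wrong.
  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x49x48x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x48x46x28xf32>
  return
}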

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 22 additions & 22 deletions
@@ -880,22 +880,22 @@ func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : ten
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
 // CHECK-LABEL: @conv3d_f32
-func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
-  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
+func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<43x3x4x5x27xf32>, %bias: tensor<43xf32>) -> () {
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<43x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x43xf32>) permutation = [1, 2, 3, 4, 0]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x43xf32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x47x45x43x28xf32>) {
+  // CHECK-SAME: ins(%arg2 : tensor<43xf32>) outs(%[[INIT]] : tensor<1x47x45x43x43xf32>) {
   // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
   // CHECK: linalg.yield %[[IN]] : f32
-  // CHECK: } -> tensor<1x47x45x43x28xf32>
+  // CHECK: } -> tensor<1x47x45x43x43xf32>
   // CHECK: linalg.conv_3d_ndhwc_dhwcf
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
-  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>)
-  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
+  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x43xf32>)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
   %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
   %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<43x3x4x5x27xf32>, tensor<43xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x43xf32>
   return
 }
@@ -921,40 +921,40 @@ func.func @conv3d_scalar_bias_f32(%input: tensor<1x49x48x47x27xf32>, %weights: t
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
 // CHECK-LABEL: @conv3d_i8
-func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
-  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
+func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<43x3x4x5x27xi8>, %bias: tensor<43xi32>) -> () {
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<43x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x43xi8>) permutation = [1, 2, 3, 4, 0]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x43xi32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2 : tensor<28xi32>)
-  // CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x28xi32>) {
+  // CHECK-SAME: ins(%arg2 : tensor<43xi32>)
+  // CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x43xi32>) {
   // CHECK: ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
   // CHECK: linalg.yield %[[IN]] : i32
-  // CHECK: } -> tensor<1x47x45x43x28xi32>
+  // CHECK: } -> tensor<1x47x45x43x43xi32>
   // CHECK: %[[IZP:.+]] = arith.constant -128 : i32
   // CHECK: %[[FZP:.+]] = arith.constant 42 : i32
   // CHECK: linalg.conv_3d_ndhwc_dhwcf_q
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
-  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
-  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
+  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x43xi8>, i32, i32)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x43xi32>) -> tensor<1x47x45x43x43xi32>
 
   %input_zp = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
   %weight_zp = "tosa.const"() <{values = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x28xi32>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<43x3x4x5x27xi8>, tensor<43xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x43xi32>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @conv3d_f16_f32_acc
-func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<28x3x4x5x27xf16>, %bias: tensor<28xf16>) -> () {
+func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<43x3x4x5x27xf16>, %bias: tensor<43xf16>) -> () {
   %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
   %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
-  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>)
+  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>)
   // CHECK: arith.extf %{{.*}} : f16 to f32
-  // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
-  // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf16>
-  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<28x3x4x5x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x28xf16>
+  // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf16>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<43x3x4x5x27xf16>, tensor<43xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x43xf16>
   return
 }
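Finally, a sketch of the divisibility ERROR_IF using the updated test shapes, with the depth stride hypothetically changed to 4: (49 - 1 + 0 + 0 - (3 - 1) * 1) = 46 is not divisible by 4, so the verifier should emit the "wholly divisible by stride_d" diagnostic:

func.func @conv3d_bad_stride(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<43x3x4x5x27xf32>, %bias: tensor<43xf32>) -> () {
  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  // Rejected before any output-shape comparison: 46 / 4 is fractional.
  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 4, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<43x3x4x5x27xf32>, tensor<43xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x12x45x43x43xf32>
  return
}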
