Commit cd60e10

[mlir][tosa] Enhance CONV3D & DEPTHWISE_CONV2D verifier
Verify the pad, stride, dilation, and dimensions of the input/output.
1 parent 3633de7 commit cd60e10
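
With these checks in place, the verifier recomputes each spatial output dimension from the input, kernel, pad, stride, and dilation, and rejects mismatches. As a hypothetical illustration (not taken from this commit's tests), a depthwise_conv2d whose declared output height disagrees with the computed one now fails verification: the formula gives OH = (8 - 1 + 0 + 0 - (3 - 1) * 1) / 1 + 1 = 6, so declaring 7 produces "calculated output height did not match expected: calculated=6, expected=7".

  %izp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %wzp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  // Output height is declared as 7, but the formula yields 6, so the verifier errors.
  %0 = tosa.depthwise_conv2d %input, %weights, %bias, %izp, %wzp
         {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>,
          dilation = array<i64: 1, 1>}
       : (tensor<1x8x8x4xf32>, tensor<3x3x4x2xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>)
       -> tensor<1x7x6x8xf32>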

11 files changed: +427 -218 lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 152 additions & 94 deletions
@@ -428,6 +428,152 @@ static LogicalResult verifyConvOpModes(T op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ERROR_IF functions.
+// ERROR_IF is a predicate that must set an error if the condition holds.
+//===----------------------------------------------------------------------===//
+
+template <typename T>
+static LogicalResult verifyConvOpErrorIf(T op) {
+  llvm::ArrayRef<int64_t> padding = op.getPad();
+  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
+    return op.emitOpError("expect all padding values to be >= 0, got ")
+           << padding;
+
+  llvm::ArrayRef<int64_t> strides = op.getStride();
+  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
+    return op.emitOpError("expect all stride values to be >= 1, got ")
+           << strides;
+
+  llvm::ArrayRef<int64_t> dilations = op.getDilation();
+  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
+    return op.emitOpError("expect all dilation values to be >= 1, got ")
+           << dilations;
+
+  const RankedTensorType outputType =
+      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
+  if (!outputType)
+    // Skip following checks if output is not ranked
+    return success();
+
+  const RankedTensorType inputType =
+      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+  const RankedTensorType weightType =
+      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
+
+  if (inputType && weightType) {
+    const auto verifyOutputSize =
+        [&op](const int64_t inputSize, const int64_t kernelSize,
+              const int64_t outputSize, const int64_t padBefore,
+              const int64_t padAfter, const int64_t stride,
+              const int64_t dilation, const llvm::StringRef dimName,
+              const llvm::StringRef dimAxis,
+              const llvm::StringRef padBeforeName,
+              const llvm::StringRef padAfterName) -> LogicalResult {
+      if (inputSize == ShapedType::kDynamic ||
+          kernelSize == ShapedType::kDynamic)
+        return success();
+
+      // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
+
+      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
+          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
+          stride);
+      if (!calculatedOutSizeMinusOne.has_value())
+        return op.emitOpError("expected input_")
+               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+               << padAfterName << " - (kernel_" << dimName
+               << " - 1) * dilation_" << dimAxis
+               << " to be wholly divisible by stride_" << dimAxis << ", got ("
+               << inputSize << " - 1 + " << padBefore << " + " << padAfter
+               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
+               << stride;
+
+      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
+      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
+        return op.emitOpError("calculated output ")
+               << dimName << " did not match expected: "
+               << "calculated=" << calculatedOutSize
+               << ", expected=" << outputSize;
+
+      return success();
+    };
+
+    /// ERROR_IF: O != idiv_check(I - 1 + p_a + p_b - (K - 1) * d, s) + 1
+
+    // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
+    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(1), weightType.getDimSize(1),
+              outputType.getDimSize(1), padding[0], padding[1], strides[0],
+              dilations[0], "height", "y", "top", "bottom")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(2), weightType.getDimSize(2),
+              outputType.getDimSize(2), padding[2], padding[3], strides[1],
+              dilations[1], "width", "x", "left", "right")))
+        return failure();
+    }
+
+    // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
+    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(1), weightType.getDimSize(0),
+              outputType.getDimSize(1), padding[0], padding[1], strides[0],
+              dilations[0], "height", "y", "top", "bottom")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(2), weightType.getDimSize(1),
+              outputType.getDimSize(2), padding[2], padding[3], strides[1],
+              dilations[1], "width", "x", "left", "right")))
+        return failure();
+    }
+
+    // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
+    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(1), weightType.getDimSize(1),
+              outputType.getDimSize(1), padding[0], padding[1], strides[0],
+              dilations[0], "depth", "d", "front", "back")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(2), weightType.getDimSize(2),
+              outputType.getDimSize(2), padding[2], padding[3], strides[1],
+              dilations[1], "height", "y", "top", "bottom")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(3), weightType.getDimSize(3),
+              outputType.getDimSize(3), padding[4], padding[5], strides[2],
+              dilations[2], "width", "x", "left", "right")))
+        return failure();
+    }
+  }
+
+  const RankedTensorType biasType =
+      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
+  if (!biasType)
+    // Skip following checks if bias is not ranked
+    return success();
+
+  const int64_t biasChannels = biasType.getDimSize(0);
+  const int64_t outputChannels = outputType.getDimSize(3);
+  if (biasChannels == ShapedType::kDynamic ||
+      outputChannels == ShapedType::kDynamic)
+    // Skip following checks if biasChannels or outputChannels is dynamic dim
+    return success();
+
+  if (biasChannels != outputChannels && biasChannels != 1)
+    return op.emitOpError(
+               "bias channels expected to be equal to output channels (")
+           << outputChannels << ") or 1, got " << biasChannels;
+
+  return success();
+}
+
 // verify that inType and outType have same element types
 template <typename T>
 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
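
idivCheck is a pre-existing helper in TosaOps.cpp and is not part of this hunk. A minimal sketch of the semantics the calls above rely on, inferred from the call sites rather than copied from the file: exact integer division that reports failure instead of truncating.

  // Sketch (assumption, not the committed source): returns lhs / rhs when rhs
  // evenly divides lhs, and std::nullopt otherwise, letting callers emit the
  // "wholly divisible" diagnostic on failure.
  static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
    if (lhs % rhs != 0)
      return std::nullopt;
    return lhs / rhs;
  }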
@@ -2570,99 +2716,9 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
 }
 
 LogicalResult Conv2DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
     return failure();
-
-  llvm::ArrayRef<int64_t> padding = getPad();
-  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
-    return emitOpError("expect all padding values to be >= 0, got ") << padding;
-
-  llvm::ArrayRef<int64_t> strides = getStride();
-  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
-    return emitOpError("expect all stride values to be >= 1, got ") << strides;
-
-  llvm::ArrayRef<int64_t> dilations = getDilation();
-  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
-    return emitOpError("expect all dilation values to be >= 1, got ")
-           << dilations;
-
-  const RankedTensorType outputType =
-      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
-  if (!outputType)
-    // Skip following checks if output is not ranked
-    return success();
-
-  const RankedTensorType inputType =
-      llvm::dyn_cast<RankedTensorType>(getInput().getType());
-  const RankedTensorType weightType =
-      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
-
-  if (inputType && weightType) {
-    const auto verifyOutputSize =
-        [this](const int64_t inputSize, const int64_t kernelSize,
-               const int64_t outputSize, const int64_t padBefore,
-               const int64_t padAfter, const int64_t stride,
-               const int64_t dilation, const llvm::StringRef dimName,
-               const llvm::StringRef dimAxis,
-               const llvm::StringRef padBeforeName,
-               const llvm::StringRef padAfterName) -> LogicalResult {
-      if (inputSize == ShapedType::kDynamic ||
-          kernelSize == ShapedType::kDynamic)
-        return success();
-
-      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
-          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
-          stride);
-      if (!calculatedOutSizeMinusOne.has_value())
-        return emitOpError("expected input_")
-               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
-               << padAfterName << " - (kernel_" << dimName
-               << " - 1) * dilation_" << dimAxis
-               << " to be wholly divisible by stride_" << dimAxis << ", got ("
-               << inputSize << " - 1 + " << padBefore << " + " << padAfter
-               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
-               << stride;
-
-      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
-      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
-        return emitOpError("calculated output ")
-               << dimName << " did not match expected: "
-               << "calculated=" << calculatedOutSize
-               << ", expected=" << outputSize;
-
-      return success();
-    };
-
-    if (failed(verifyOutputSize(
-            inputType.getDimSize(1), weightType.getDimSize(1),
-            outputType.getDimSize(1), padding[0], padding[1], strides[0],
-            dilations[0], "height", "y", "top", "bottom")))
-      return failure();
-
-    if (failed(verifyOutputSize(
-            inputType.getDimSize(2), weightType.getDimSize(2),
-            outputType.getDimSize(2), padding[2], padding[3], strides[1],
-            dilations[1], "width", "x", "left", "right")))
-      return failure();
-  }
-
-  const RankedTensorType biasType =
-      llvm::dyn_cast<RankedTensorType>(getBias().getType());
-  if (!biasType)
-    // Skip following checks if bias is not ranked
-    return success();
-
-  const int64_t biasChannels = biasType.getDimSize(0);
-  const int64_t outputChannels = outputType.getDimSize(3);
-  if (biasChannels == ShapedType::kDynamic ||
-      outputChannels == ShapedType::kDynamic)
-    // Skip following checks if biasChannels or outputChannels is dynamic dim
-    return success();
-
-  if (biasChannels != outputChannels && biasChannels != 1)
-    return emitOpError(
-               "bias channels expected to be equal to output channels (")
-           << outputChannels << ") or 1, got " << biasChannels;
-
   return success();
 }
 
@@ -2737,7 +2793,8 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
 }
 
 LogicalResult Conv3DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
     return failure();
   return success();
 }
@@ -2847,7 +2904,8 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
 }
 
 LogicalResult DepthwiseConv2DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
     return failure();
   return success();
 }
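
The test updates in the next file follow directly from the ERROR_IF output-size formula. For the conv3d shapes used below (input tensor<1x49x48x47x27>, weights tensor<43x3x4x5x27>, zero padding, unit strides and dilations):

  OD = (49 - 1 + 0 + 0 - (3 - 1) * 1) / 1 + 1 = 47
  OH = (48 - 1 + 0 + 0 - (4 - 1) * 1) / 1 + 1 = 45
  OW = (47 - 1 + 0 + 0 - (5 - 1) * 1) / 1 + 1 = 43

which matches the expected output shape 1x47x45x43x43, the trailing 43 being the output channel count.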

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 22 additions & 22 deletions
@@ -878,22 +878,22 @@ func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : ten
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
 // CHECK-LABEL: @conv3d_f32
-func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
-  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
+func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<43x3x4x5x27xf32>, %bias: tensor<43xf32>) -> () {
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<43x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x43xf32>) permutation = [1, 2, 3, 4, 0]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x43xf32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x47x45x43x28xf32>) {
+  // CHECK-SAME: ins(%arg2 : tensor<43xf32>) outs(%[[INIT]] : tensor<1x47x45x43x43xf32>) {
   // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
   // CHECK: linalg.yield %[[IN]] : f32
-  // CHECK: } -> tensor<1x47x45x43x28xf32>
+  // CHECK: } -> tensor<1x47x45x43x43xf32>
   // CHECK: linalg.conv_3d_ndhwc_dhwcf
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
-  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>)
-  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
+  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x43xf32>)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
   %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
   %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<43x3x4x5x27xf32>, tensor<43xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x43xf32>
   return
 }
 
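Note the weight output-channel count changes from 28 to 43 in these tests, presumably because the new bias check compares the bias size against outputType.getDimSize(3), which for the 5-D conv3d output 1x47x45x43xC is 43 regardless of C; a 28-channel bias would now be rejected, while 43 channels keep the tests passing.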
@@ -919,40 +919,40 @@ func.func @conv3d_scalar_bias_f32(%input: tensor<1x49x48x47x27xf32>, %weights: t
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
 // CHECK-LABEL: @conv3d_i8
-func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
-  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
+func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<43x3x4x5x27xi8>, %bias: tensor<43xi32>) -> () {
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<43x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x43xi8>) permutation = [1, 2, 3, 4, 0]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x43xi32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2 : tensor<28xi32>)
-  // CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x28xi32>) {
+  // CHECK-SAME: ins(%arg2 : tensor<43xi32>)
+  // CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x43xi32>) {
   // CHECK: ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
   // CHECK: linalg.yield %[[IN]] : i32
-  // CHECK: } -> tensor<1x47x45x43x28xi32>
+  // CHECK: } -> tensor<1x47x45x43x43xi32>
   // CHECK: %[[IZP:.+]] = arith.constant -128 : i32
   // CHECK: %[[FZP:.+]] = arith.constant 42 : i32
   // CHECK: linalg.conv_3d_ndhwc_dhwcf_q
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
-  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
-  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
+  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x43xi8>, i32, i32)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x43xi32>) -> tensor<1x47x45x43x43xi32>
 
   %input_zp = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
   %weight_zp = "tosa.const"() <{values = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x28xi32>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<43x3x4x5x27xi8>, tensor<43xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x43xi32>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @conv3d_f16_f32_acc
-func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<28x3x4x5x27xf16>, %bias: tensor<28xf16>) -> () {
+func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<43x3x4x5x27xf16>, %bias: tensor<43xf16>) -> () {
   %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
   %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
-  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>)
+  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>)
   // CHECK: arith.extf %{{.*}} : f16 to f32
-  // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
-  // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf16>
-  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<28x3x4x5x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x28xf16>
+  // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf16>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<43x3x4x5x27xf16>, tensor<43xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x43xf16>
   return
 }