
Commit e2a9902

[mlir][tosa] Enhance CONV3D & DEPTHWISE_CONV2D verifier
Verify the pad, stride, dilation, and input/output dimensions.
1 parent 6727d58 commit e2a9902

File tree

11 files changed: +425 -218 lines changed


mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 150 additions & 94 deletions
@@ -428,6 +428,150 @@ static LogicalResult verifyConvOpModes(T op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ERROR_IF functions.
+// ERROR_IF is a predicate that must set an error if the condition holds.
+//===----------------------------------------------------------------------===//
+
+template <typename T>
+static LogicalResult verifyConvOpErrorIf(T op) {
+  llvm::ArrayRef<int64_t> padding = op.getPad();
+  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
+    return op.emitOpError("expect all padding values to be >= 0, got ")
+           << padding;
+
+  llvm::ArrayRef<int64_t> strides = op.getStride();
+  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
+    return op.emitOpError("expect all stride values to be >= 1, got ")
+           << strides;
+
+  llvm::ArrayRef<int64_t> dilations = op.getDilation();
+  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
+    return op.emitOpError("expect all dilation values to be >= 1, got ")
+           << dilations;
+
+  const RankedTensorType outputType =
+      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
+  if (!outputType)
+    // Skip following checks if output is not ranked
+    return success();
+
+  const RankedTensorType inputType =
+      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+  const RankedTensorType weightType =
+      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
+
+  if (inputType && weightType) {
+    const auto verifyOutputSize =
+        [&op](const int64_t inputSize, const int64_t kernelSize,
+              const int64_t outputSize, const int64_t padBefore,
+              const int64_t padAfter, const int64_t stride,
+              const int64_t dilation, const llvm::StringRef dimName,
+              const llvm::StringRef dimAxis,
+              const llvm::StringRef padBeforeName,
+              const llvm::StringRef padAfterName) -> LogicalResult {
+      if (inputSize == ShapedType::kDynamic ||
+          kernelSize == ShapedType::kDynamic)
+        return success();
+
+      // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
+
+      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
+          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
+          stride);
+      if (!calculatedOutSizeMinusOne.has_value())
+        return op.emitOpError("expected input_")
+               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+               << padAfterName << " - (kernel_" << dimName
+               << " - 1) * dilation_" << dimAxis
+               << " to be wholly divisible by stride_" << dimAxis << ", got ("
+               << inputSize << " - 1 + " << padBefore << " + " << padAfter
+               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
+               << stride;
+
+      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
+      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
+        return op.emitOpError("calculated output ")
+               << dimName << " did not match expected: "
+               << "calculated=" << calculatedOutSize
+               << ", expected=" << outputSize;
+
+      return success();
+    };
+
+    // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
+    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(1), weightType.getDimSize(1),
+              outputType.getDimSize(1), padding[0], padding[1], strides[0],
+              dilations[0], "height", "y", "top", "bottom")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(2), weightType.getDimSize(2),
+              outputType.getDimSize(2), padding[2], padding[3], strides[1],
+              dilations[1], "width", "x", "left", "right")))
+        return failure();
+    }
+
+    // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
+    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(1), weightType.getDimSize(0),
+              outputType.getDimSize(1), padding[0], padding[1], strides[0],
+              dilations[0], "height", "y", "top", "bottom")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(2), weightType.getDimSize(1),
+              outputType.getDimSize(2), padding[2], padding[3], strides[1],
+              dilations[1], "width", "x", "left", "right")))
+        return failure();
+    }
+
+    // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
+    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(1), weightType.getDimSize(1),
+              outputType.getDimSize(1), padding[0], padding[1], strides[0],
+              dilations[0], "depth", "d", "front", "back")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(2), weightType.getDimSize(2),
+              outputType.getDimSize(2), padding[2], padding[3], strides[1],
+              dilations[1], "height", "y", "top", "bottom")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(3), weightType.getDimSize(3),
+              outputType.getDimSize(3), padding[4], padding[5], strides[2],
+              dilations[2], "width", "x", "left", "right")))
+        return failure();
+    }
+  }
+
+  const RankedTensorType biasType =
+      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
+  if (!biasType)
+    // Skip following checks if bias is not ranked
+    return success();
+
+  const int64_t biasChannels = biasType.getDimSize(0);
+  const int64_t outputChannels = outputType.getDimSize(3);
+  if (biasChannels == ShapedType::kDynamic ||
+      outputChannels == ShapedType::kDynamic)
+    // Skip following checks if biasChannels or outputChannels is dynamic dim
+    return success();
+
+  if (biasChannels != outputChannels && biasChannels != 1)
+    return op.emitOpError(
+               "bias channels expected to be equal to output channels (")
+           << outputChannels << ") or 1, got " << biasChannels;
+
+  return success();
+}
+
 // verify that inType and outType have same element types
 template <typename T>
 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -2586,99 +2730,9 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
 }
 
 LogicalResult Conv2DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
     return failure();
-
-  llvm::ArrayRef<int64_t> padding = getPad();
-  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
-    return emitOpError("expect all padding values to be >= 0, got ") << padding;
-
-  llvm::ArrayRef<int64_t> strides = getStride();
-  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
-    return emitOpError("expect all stride values to be >= 1, got ") << strides;
-
-  llvm::ArrayRef<int64_t> dilations = getDilation();
-  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
-    return emitOpError("expect all dilation values to be >= 1, got ")
-           << dilations;
-
-  const RankedTensorType outputType =
-      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
-  if (!outputType)
-    // Skip following checks if output is not ranked
-    return success();
-
-  const RankedTensorType inputType =
-      llvm::dyn_cast<RankedTensorType>(getInput().getType());
-  const RankedTensorType weightType =
-      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
-
-  if (inputType && weightType) {
-    const auto verifyOutputSize =
-        [this](const int64_t inputSize, const int64_t kernelSize,
-               const int64_t outputSize, const int64_t padBefore,
-               const int64_t padAfter, const int64_t stride,
-               const int64_t dilation, const llvm::StringRef dimName,
-               const llvm::StringRef dimAxis,
-               const llvm::StringRef padBeforeName,
-               const llvm::StringRef padAfterName) -> LogicalResult {
-      if (inputSize == ShapedType::kDynamic ||
-          kernelSize == ShapedType::kDynamic)
-        return success();
-
-      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
-          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
-          stride);
-      if (!calculatedOutSizeMinusOne.has_value())
-        return emitOpError("expected input_")
-               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
-               << padAfterName << " - (kernel_" << dimName
-               << " - 1) * dilation_" << dimAxis
-               << " to be wholly divisible by stride_" << dimAxis << ", got ("
-               << inputSize << " - 1 + " << padBefore << " + " << padAfter
-               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
-               << stride;
-
-      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
-      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
-        return emitOpError("calculated output ")
-               << dimName << " did not match expected: "
-               << "calculated=" << calculatedOutSize
-               << ", expected=" << outputSize;
-
-      return success();
-    };
-
-    if (failed(verifyOutputSize(
-            inputType.getDimSize(1), weightType.getDimSize(1),
-            outputType.getDimSize(1), padding[0], padding[1], strides[0],
-            dilations[0], "height", "y", "top", "bottom")))
-      return failure();
-
-    if (failed(verifyOutputSize(
-            inputType.getDimSize(2), weightType.getDimSize(2),
-            outputType.getDimSize(2), padding[2], padding[3], strides[1],
-            dilations[1], "width", "x", "left", "right")))
-      return failure();
-  }
-
-  const RankedTensorType biasType =
-      llvm::dyn_cast<RankedTensorType>(getBias().getType());
-  if (!biasType)
-    // Skip following checks if bias is not ranked
-    return success();
-
-  const int64_t biasChannels = biasType.getDimSize(0);
-  const int64_t outputChannels = outputType.getDimSize(3);
-  if (biasChannels == ShapedType::kDynamic ||
-      outputChannels == ShapedType::kDynamic)
-    // Skip following checks if biasChannels or outputChannels is dynamic dim
-    return success();
-
-  if (biasChannels != outputChannels && biasChannels != 1)
-    return emitOpError(
-               "bias channels expected to be equal to output channels (")
-           << outputChannels << ") or 1, got " << biasChannels;
-
   return success();
 }
 

@@ -2753,7 +2807,8 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
 }
 
 LogicalResult Conv3DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
     return failure();
   return success();
 }
@@ -2863,7 +2918,8 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
 }
 
 LogicalResult DepthwiseConv2DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
    return failure();
   return success();
 }
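The relation enforced per spatial dimension by verifyConvOpErrorIf above is the TOSA ERROR_IF rule O = idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1. A minimal standalone sketch of that arithmetic in plain C++ (not the MLIR sources; this idivCheck is an illustrative stand-in assumed to match the exact-division semantics of the helper the verifier calls):

#include <cassert>
#include <cstdint>
#include <optional>

// Exact integer division: nullopt when lhs is not wholly divisible by rhs.
static std::optional<int64_t> idivCheck(int64_t lhs, int64_t rhs) {
  if (lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}

// O = idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
static std::optional<int64_t> convOutputSize(int64_t in, int64_t kernel,
                                             int64_t padBefore,
                                             int64_t padAfter, int64_t stride,
                                             int64_t dilation) {
  const std::optional<int64_t> q = idivCheck(
      in - 1 + padBefore + padAfter - (kernel - 1) * dilation, stride);
  if (!q)
    return std::nullopt; // the verifier emits the "wholly divisible" error
  return *q + 1;
}

int main() {
  // Depth dimension of the conv3d tests below: I = 49, K = 3, no padding,
  // stride 1, dilation 1 -> O = 47.
  assert(convOutputSize(49, 3, 0, 0, 1, 1) == 47);
  // A stride that does not divide evenly is rejected: (10 - 1 - 2) % 2 != 0.
  assert(!convOutputSize(10, 3, 0, 0, 2, 1).has_value());
  return 0;
}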

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 22 additions & 22 deletions
@@ -878,22 +878,22 @@ func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : ten
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
 // CHECK-LABEL: @conv3d_f32
-func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
-  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
+func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<43x3x4x5x27xf32>, %bias: tensor<43xf32>) -> () {
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<43x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x43xf32>) permutation = [1, 2, 3, 4, 0]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x43xf32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x47x45x43x28xf32>) {
+  // CHECK-SAME: ins(%arg2 : tensor<43xf32>) outs(%[[INIT]] : tensor<1x47x45x43x43xf32>) {
   // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
   // CHECK: linalg.yield %[[IN]] : f32
-  // CHECK: } -> tensor<1x47x45x43x28xf32>
+  // CHECK: } -> tensor<1x47x45x43x43xf32>
   // CHECK: linalg.conv_3d_ndhwc_dhwcf
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
-  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>)
-  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
+  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x43xf32>)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
   %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
   %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<43x3x4x5x27xf32>, tensor<43xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x43xf32>
   return
 }
 
@@ -919,40 +919,40 @@ func.func @conv3d_scalar_bias_f32(%input: tensor<1x49x48x47x27xf32>, %weights: t
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
 // CHECK-LABEL: @conv3d_i8
-func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
-  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
+func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<43x3x4x5x27xi8>, %bias: tensor<43xi32>) -> () {
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<43x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x43xi8>) permutation = [1, 2, 3, 4, 0]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x43xi32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2 : tensor<28xi32>)
-  // CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x28xi32>) {
+  // CHECK-SAME: ins(%arg2 : tensor<43xi32>)
+  // CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x43xi32>) {
   // CHECK: ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
   // CHECK: linalg.yield %[[IN]] : i32
-  // CHECK: } -> tensor<1x47x45x43x28xi32>
+  // CHECK: } -> tensor<1x47x45x43x43xi32>
   // CHECK: %[[IZP:.+]] = arith.constant -128 : i32
   // CHECK: %[[FZP:.+]] = arith.constant 42 : i32
   // CHECK: linalg.conv_3d_ndhwc_dhwcf_q
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
-  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
-  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
+  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x43xi8>, i32, i32)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x43xi32>) -> tensor<1x47x45x43x43xi32>
 
   %input_zp = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
   %weight_zp = "tosa.const"() <{values = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x28xi32>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<43x3x4x5x27xi8>, tensor<43xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x43xi32>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @conv3d_f16_f32_acc
-func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<28x3x4x5x27xf16>, %bias: tensor<28xf16>) -> () {
+func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<43x3x4x5x27xf16>, %bias: tensor<43xf16>) -> () {
   %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
   %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
-  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>)
+  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>)
   // CHECK: arith.extf %{{.*}} : f16 to f32
-  // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
-  // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf16>
-  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<28x3x4x5x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x28xf16>
+  // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf16>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<43x3x4x5x27xf16>, tensor<43xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x43xf16>
   return
 }
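The bias rule at the end of verifyConvOpErrorIf accepts a ranked bias only when its single dimension equals the output channel count or is 1 (broadcast); the updated tests above keep the bias length equal to the channel dimension for this reason. A minimal sketch of that predicate, again in plain C++ with illustrative names (not the MLIR sources):

#include <cassert>
#include <cstdint>

// Bias is accepted when it is per-channel or a broadcast scalar (size 1).
static bool biasShapeOk(int64_t biasChannels, int64_t outputChannels) {
  return biasChannels == outputChannels || biasChannels == 1;
}

int main() {
  assert(biasShapeOk(43, 43));  // per-channel bias, as in the tests above
  assert(biasShapeOk(1, 43));   // broadcast bias
  assert(!biasShapeOk(28, 43)); // mismatch -> the verifier emits an error
  return 0;
}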
