@@ -428,6 +428,150 @@ static LogicalResult verifyConvOpModes(T op) {
428
428
return success ();
429
429
}
430
430
431
+ // ===----------------------------------------------------------------------===//
432
+ // ERROR_IF functions.
433
+ // ERROR_IF is a predicate that must set an error if the condition holds.
434
+ // ===----------------------------------------------------------------------===//
435
+
436
+ template <typename T>
437
+ static LogicalResult verifyConvOpErrorIf (T op) {
438
+ llvm::ArrayRef<int64_t > padding = op.getPad ();
439
+ if (llvm::any_of (padding, [](int64_t p) { return p < 0 ; }))
440
+ return op.emitOpError (" expect all padding values to be >= 0, got " )
441
+ << padding;
442
+
443
+ llvm::ArrayRef<int64_t > strides = op.getStride ();
444
+ if (llvm::any_of (strides, [](int64_t s) { return s < 1 ; }))
445
+ return op.emitOpError (" expect all stride values to be >= 1, got " )
446
+ << strides;
447
+
448
+ llvm::ArrayRef<int64_t > dilations = op.getDilation ();
449
+ if (llvm::any_of (dilations, [](int64_t d) { return d < 1 ; }))
450
+ return op.emitOpError (" expect all dilation values to be >= 1, got " )
451
+ << dilations;
452
+
453
+ const RankedTensorType outputType =
454
+ llvm::dyn_cast<RankedTensorType>(op.getOutput ().getType ());
455
+ if (!outputType)
456
+ // Skip following checks if output is not ranked
457
+ return success ();
458
+
459
+ const RankedTensorType inputType =
460
+ llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
461
+ const RankedTensorType weightType =
462
+ llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
463
+
464
+ if (inputType && weightType) {
465
+ const auto verifyOutputSize =
466
+ [&op](const int64_t inputSize, const int64_t kernelSize,
467
+ const int64_t outputSize, const int64_t padBefore,
468
+ const int64_t padAfter, const int64_t stride,
469
+ const int64_t dilation, const llvm::StringRef dimName,
470
+ const llvm::StringRef dimAxis,
471
+ const llvm::StringRef padBeforeName,
472
+ const llvm::StringRef padAfterName) -> LogicalResult {
473
+ if (inputSize == ShapedType::kDynamic ||
474
+ kernelSize == ShapedType::kDynamic )
475
+ return success ();
476
+
477
+ // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
478
+
479
+ const std::optional<int64_t > calculatedOutSizeMinusOne = idivCheck (
480
+ inputSize - 1 + padBefore + padAfter - (kernelSize - 1 ) * dilation,
481
+ stride);
482
+ if (!calculatedOutSizeMinusOne.has_value ())
483
+ return op.emitOpError (" expected input_" )
484
+ << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
485
+ << padAfterName << " - (kernel_" << dimName
486
+ << " - 1) * dilation_" << dimAxis
487
+ << " to be wholly divisible by stride_" << dimAxis << " , got ("
488
+ << inputSize << " - 1 + " << padBefore << " + " << padAfter
489
+ << " - (" << kernelSize << " - 1) * " << dilation << " ) / "
490
+ << stride;
491
+
492
+ const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value () + 1 ;
493
+ if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
494
+ return op.emitOpError (" calculated output " )
495
+ << dimName << " did not match expected: "
496
+ << " calculated=" << calculatedOutSize
497
+ << " , expected=" << outputSize;
498
+
499
+ return success ();
500
+ };
501
+
502
+ // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
503
+ if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
504
+ if (failed (verifyOutputSize (
505
+ inputType.getDimSize (1 ), weightType.getDimSize (1 ),
506
+ outputType.getDimSize (1 ), padding[0 ], padding[1 ], strides[0 ],
507
+ dilations[0 ], " height" , " y" , " top" , " bottom" )))
508
+ return failure ();
509
+
510
+ if (failed (verifyOutputSize (
511
+ inputType.getDimSize (2 ), weightType.getDimSize (2 ),
512
+ outputType.getDimSize (2 ), padding[2 ], padding[3 ], strides[1 ],
513
+ dilations[1 ], " width" , " x" , " left" , " right" )))
514
+ return failure ();
515
+ }
516
+
517
+ // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
518
+ if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
519
+ if (failed (verifyOutputSize (
520
+ inputType.getDimSize (1 ), weightType.getDimSize (0 ),
521
+ outputType.getDimSize (1 ), padding[0 ], padding[1 ], strides[0 ],
522
+ dilations[0 ], " height" , " y" , " top" , " bottom" )))
523
+ return failure ();
524
+
525
+ if (failed (verifyOutputSize (
526
+ inputType.getDimSize (2 ), weightType.getDimSize (1 ),
527
+ outputType.getDimSize (2 ), padding[2 ], padding[3 ], strides[1 ],
528
+ dilations[1 ], " width" , " x" , " left" , " right" )))
529
+ return failure ();
530
+ }
531
+
532
+ // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
533
+ if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
534
+ if (failed (verifyOutputSize (
535
+ inputType.getDimSize (1 ), weightType.getDimSize (1 ),
536
+ outputType.getDimSize (1 ), padding[0 ], padding[1 ], strides[0 ],
537
+ dilations[0 ], " depth" , " d" , " front" , " back" )))
538
+ return failure ();
539
+
540
+ if (failed (verifyOutputSize (
541
+ inputType.getDimSize (2 ), weightType.getDimSize (2 ),
542
+ outputType.getDimSize (2 ), padding[2 ], padding[3 ], strides[1 ],
543
+ dilations[1 ], " height" , " y" , " top" , " bottom" )))
544
+ return failure ();
545
+
546
+ if (failed (verifyOutputSize (
547
+ inputType.getDimSize (3 ), weightType.getDimSize (3 ),
548
+ outputType.getDimSize (3 ), padding[4 ], padding[5 ], strides[2 ],
549
+ dilations[2 ], " width" , " x" , " left" , " right" )))
550
+ return failure ();
551
+ }
552
+ }
553
+
554
+ const RankedTensorType biasType =
555
+ llvm::dyn_cast<RankedTensorType>(op.getBias ().getType ());
556
+ if (!biasType)
557
+ // Skip following checks if bias is not ranked
558
+ return success ();
559
+
560
+ const int64_t biasChannels = biasType.getDimSize (0 );
561
+ const int64_t outputChannels = outputType.getDimSize (3 );
562
+ if (biasChannels == ShapedType::kDynamic ||
563
+ outputChannels == ShapedType::kDynamic )
564
+ // Skip following checks if biasChannels or outputChannels is dynamic dim
565
+ return success ();
566
+
567
+ if (biasChannels != outputChannels && biasChannels != 1 )
568
+ return op.emitOpError (
569
+ " bias channels expected to be equal to output channels (" )
570
+ << outputChannels << " ) or 1, got " << biasChannels;
571
+
572
+ return success ();
573
+ }
574
+
431
575
// verify that inType and outType have same element types
432
576
template <typename T>
433
577
static LogicalResult verifySameElementTypes (T op, Type inType, Type outType) {
@@ -2767,99 +2911,9 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
2767
2911
}
2768
2912
2769
2913
LogicalResult Conv2DOp::verify () {
2770
- if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
2914
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed () ||
2915
+ verifyConvOpErrorIf (*this ).failed ())
2771
2916
return failure ();
2772
-
2773
- llvm::ArrayRef<int64_t > padding = getPad ();
2774
- if (llvm::any_of (padding, [](int64_t p) { return p < 0 ; }))
2775
- return emitOpError (" expect all padding values to be >= 0, got " ) << padding;
2776
-
2777
- llvm::ArrayRef<int64_t > strides = getStride ();
2778
- if (llvm::any_of (strides, [](int64_t s) { return s < 1 ; }))
2779
- return emitOpError (" expect all stride values to be >= 1, got " ) << strides;
2780
-
2781
- llvm::ArrayRef<int64_t > dilations = getDilation ();
2782
- if (llvm::any_of (dilations, [](int64_t d) { return d < 1 ; }))
2783
- return emitOpError (" expect all dilation values to be >= 1, got " )
2784
- << dilations;
2785
-
2786
- const RankedTensorType outputType =
2787
- llvm::dyn_cast<RankedTensorType>(getOutput ().getType ());
2788
- if (!outputType)
2789
- // Skip following checks if output is not ranked
2790
- return success ();
2791
-
2792
- const RankedTensorType inputType =
2793
- llvm::dyn_cast<RankedTensorType>(getInput ().getType ());
2794
- const RankedTensorType weightType =
2795
- llvm::dyn_cast<RankedTensorType>(getWeight ().getType ());
2796
-
2797
- if (inputType && weightType) {
2798
- const auto verifyOutputSize =
2799
- [this ](const int64_t inputSize, const int64_t kernelSize,
2800
- const int64_t outputSize, const int64_t padBefore,
2801
- const int64_t padAfter, const int64_t stride,
2802
- const int64_t dilation, const llvm::StringRef dimName,
2803
- const llvm::StringRef dimAxis,
2804
- const llvm::StringRef padBeforeName,
2805
- const llvm::StringRef padAfterName) -> LogicalResult {
2806
- if (inputSize == ShapedType::kDynamic ||
2807
- kernelSize == ShapedType::kDynamic )
2808
- return success ();
2809
-
2810
- const std::optional<int64_t > calculatedOutSizeMinusOne = idivCheck (
2811
- inputSize - 1 + padBefore + padAfter - (kernelSize - 1 ) * dilation,
2812
- stride);
2813
- if (!calculatedOutSizeMinusOne.has_value ())
2814
- return emitOpError (" expected input_" )
2815
- << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
2816
- << padAfterName << " - (kernel_" << dimName
2817
- << " - 1) * dilation_" << dimAxis
2818
- << " to be wholly divisible by stride_" << dimAxis << " , got ("
2819
- << inputSize << " - 1 + " << padBefore << " + " << padAfter
2820
- << " - (" << kernelSize << " - 1) * " << dilation << " ) / "
2821
- << stride;
2822
-
2823
- const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value () + 1 ;
2824
- if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
2825
- return emitOpError (" calculated output " )
2826
- << dimName << " did not match expected: "
2827
- << " calculated=" << calculatedOutSize
2828
- << " , expected=" << outputSize;
2829
-
2830
- return success ();
2831
- };
2832
-
2833
- if (failed (verifyOutputSize (
2834
- inputType.getDimSize (1 ), weightType.getDimSize (1 ),
2835
- outputType.getDimSize (1 ), padding[0 ], padding[1 ], strides[0 ],
2836
- dilations[0 ], " height" , " y" , " top" , " bottom" )))
2837
- return failure ();
2838
-
2839
- if (failed (verifyOutputSize (
2840
- inputType.getDimSize (2 ), weightType.getDimSize (2 ),
2841
- outputType.getDimSize (2 ), padding[2 ], padding[3 ], strides[1 ],
2842
- dilations[1 ], " width" , " x" , " left" , " right" )))
2843
- return failure ();
2844
- }
2845
-
2846
- const RankedTensorType biasType =
2847
- llvm::dyn_cast<RankedTensorType>(getBias ().getType ());
2848
- if (!biasType)
2849
- // Skip following checks if bias is not ranked
2850
- return success ();
2851
-
2852
- const int64_t biasChannels = biasType.getDimSize (0 );
2853
- const int64_t outputChannels = outputType.getDimSize (3 );
2854
- if (biasChannels == ShapedType::kDynamic ||
2855
- outputChannels == ShapedType::kDynamic )
2856
- // Skip following checks if biasChannels or outputChannels is dynamic dim
2857
- return success ();
2858
-
2859
- if (biasChannels != outputChannels && biasChannels != 1 )
2860
- return emitOpError (
2861
- " bias channels expected to be equal to output channels (" )
2862
- << outputChannels << " ) or 1, got " << biasChannels;
2863
2917
return success ();
2864
2918
}
2865
2919
@@ -2934,7 +2988,8 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
2934
2988
}
2935
2989
2936
2990
LogicalResult Conv3DOp::verify () {
2937
- if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
2991
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed () ||
2992
+ verifyConvOpErrorIf (*this ).failed ())
2938
2993
return failure ();
2939
2994
return success ();
2940
2995
}
@@ -3044,7 +3099,8 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3044
3099
}
3045
3100
3046
3101
LogicalResult DepthwiseConv2DOp::verify () {
3047
- if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
3102
+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed () ||
3103
+ verifyConvOpErrorIf (*this ).failed ())
3048
3104
return failure ();
3049
3105
return success ();
3050
3106
}
0 commit comments