@@ -428,6 +428,150 @@ static LogicalResult verifyConvOpModes(T op) {
428
428
return success ();
429
429
}
430
430
431
+ // ===----------------------------------------------------------------------===//
432
+ // ERROR_IF functions.
433
+ // ERROR_IF is a predicate that must set an error if the condition holds.
434
+ // ===----------------------------------------------------------------------===//
435
+
436
+ template <typename T>
437
+ static LogicalResult verifyConvOpErrorIf (T op) {
438
+ llvm::ArrayRef<int64_t > padding = op.getPad ();
439
+ if (llvm::any_of (padding, [](int64_t p) { return p < 0 ; }))
440
+ return op.emitOpError (" expect all padding values to be >= 0, got " )
441
+ << padding;
442
+
443
+ llvm::ArrayRef<int64_t > strides = op.getStride ();
444
+ if (llvm::any_of (strides, [](int64_t s) { return s < 1 ; }))
445
+ return op.emitOpError (" expect all stride values to be >= 1, got " )
446
+ << strides;
447
+
448
+ llvm::ArrayRef<int64_t > dilations = op.getDilation ();
449
+ if (llvm::any_of (dilations, [](int64_t d) { return d < 1 ; }))
450
+ return op.emitOpError (" expect all dilation values to be >= 1, got " )
451
+ << dilations;
452
+
453
+ const RankedTensorType outputType =
454
+ llvm::dyn_cast<RankedTensorType>(op.getOutput ().getType ());
455
+ if (!outputType)
456
+ // Skip following checks if output is not ranked
457
+ return success ();
458
+
459
+ const RankedTensorType inputType =
460
+ llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
461
+ const RankedTensorType weightType =
462
+ llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
463
+
464
+ if (inputType && weightType) {
465
+ const auto verifyOutputSize =
466
+ [&op](const int64_t inputSize, const int64_t kernelSize,
467
+ const int64_t outputSize, const int64_t padBefore,
468
+ const int64_t padAfter, const int64_t stride,
469
+ const int64_t dilation, const llvm::StringRef dimName,
470
+ const llvm::StringRef dimAxis,
471
+ const llvm::StringRef padBeforeName,
472
+ const llvm::StringRef padAfterName) -> LogicalResult {
473
+ if (inputSize == ShapedType::kDynamic ||
474
+ kernelSize == ShapedType::kDynamic )
475
+ return success ();
476
+
477
+ // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
478
+
479
+ const std::optional<int64_t > calculatedOutSizeMinusOne = idivCheck (
480
+ inputSize - 1 + padBefore + padAfter - (kernelSize - 1 ) * dilation,
481
+ stride);
482
+ if (!calculatedOutSizeMinusOne.has_value ())
483
+ return op.emitOpError (" expected input_" )
484
+ << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
485
+ << padAfterName << " - (kernel_" << dimName
486
+ << " - 1) * dilation_" << dimAxis
487
+ << " to be wholly divisible by stride_" << dimAxis << " , got ("
488
+ << inputSize << " - 1 + " << padBefore << " + " << padAfter
489
+ << " - (" << kernelSize << " - 1) * " << dilation << " ) / "
490
+ << stride;
491
+
492
+ const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value () + 1 ;
493
+ if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
494
+ return op.emitOpError (" calculated output " )
495
+ << dimName << " did not match expected: "
496
+ << " calculated=" << calculatedOutSize
497
+ << " , expected=" << outputSize;
498
+
499
+ return success ();
500
+ };
501
+
502
+ // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
503
+ if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
504
+ if (failed (verifyOutputSize (
505
+ inputType.getDimSize (1 ), weightType.getDimSize (1 ),
506
+ outputType.getDimSize (1 ), padding[0 ], padding[1 ], strides[0 ],
507
+ dilations[0 ], " height" , " y" , " top" , " bottom" )))
508
+ return failure ();
509
+
510
+ if (failed (verifyOutputSize (
511
+ inputType.getDimSize (2 ), weightType.getDimSize (2 ),
512
+ outputType.getDimSize (2 ), padding[2 ], padding[3 ], strides[1 ],
513
+ dilations[1 ], " width" , " x" , " left" , " right" )))
514
+ return failure ();
515
+ }
516
+
517
+ // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
518
+ if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
519
+ if (failed (verifyOutputSize (
520
+ inputType.getDimSize (1 ), weightType.getDimSize (0 ),
521
+ outputType.getDimSize (1 ), padding[0 ], padding[1 ], strides[0 ],
522
+ dilations[0 ], " height" , " y" , " top" , " bottom" )))
523
+ return failure ();
524
+
525
+ if (failed (verifyOutputSize (
526
+ inputType.getDimSize (2 ), weightType.getDimSize (1 ),
527
+ outputType.getDimSize (2 ), padding[2 ], padding[3 ], strides[1 ],
528
+ dilations[1 ], " width" , " x" , " left" , " right" )))
529
+ return failure ();
530
+ }
531
+
532
+ // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
533
+ if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
534
+ if (failed (verifyOutputSize (
535
+ inputType.getDimSize (1 ), weightType.getDimSize (1 ),
536
+ outputType.getDimSize (1 ), padding[0 ], padding[1 ], strides[0 ],
537
+ dilations[0 ], " depth" , " d" , " front" , " back" )))
538
+ return failure ();
539
+
540
+ if (failed (verifyOutputSize (
541
+ inputType.getDimSize (2 ), weightType.getDimSize (2 ),
542
+ outputType.getDimSize (2 ), padding[2 ], padding[3 ], strides[1 ],
543
+ dilations[1 ], " height" , " y" , " top" , " bottom" )))
544
+ return failure ();
545
+
546
+ if (failed (verifyOutputSize (
547
+ inputType.getDimSize (3 ), weightType.getDimSize (3 ),
548
+ outputType.getDimSize (3 ), padding[4 ], padding[5 ], strides[2 ],
549
+ dilations[2 ], " width" , " x" , " left" , " right" )))
550
+ return failure ();
551
+ }
552
+ }
553
+
554
+ const RankedTensorType biasType =
555
+ llvm::dyn_cast<RankedTensorType>(op.getBias ().getType ());
556
+ if (!biasType)
557
+ // Skip following checks if bias is not ranked
558
+ return success ();
559
+
560
+ const int64_t biasChannels = biasType.getDimSize (0 );
561
+ const int64_t outputChannels = outputType.getDimSize (3 );
562
+ if (biasChannels == ShapedType::kDynamic ||
563
+ outputChannels == ShapedType::kDynamic )
564
+ // Skip following checks if biasChannels or outputChannels is dynamic dim
565
+ return success ();
566
+
567
+ if (biasChannels != outputChannels && biasChannels != 1 )
568
+ return op.emitOpError (
569
+ " bias channels expected to be equal to output channels (" )
570
+ << outputChannels << " ) or 1, got " << biasChannels;
571
+
572
+ return success ();
573
+ }
574
+
431
575
// verify that inType and outType have same element types
432
576
template <typename T>
433
577
static LogicalResult verifySameElementTypes (T op, Type inType, Type outType) {
@@ -2586,99 +2730,9 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
2586
2730
}
2587
2731
2588
2732
// Verifier for tosa.conv2d: delegates to the shared convolution checks —
// structural (verifyConvOp), accumulator/zero-point modes (verifyConvOpModes),
// and spec-level ERROR_IF predicates (verifyConvOpErrorIf).
LogicalResult Conv2DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
      verifyConvOpErrorIf(*this).failed())
    return failure();
  return success();
}
2684
2738
@@ -2753,7 +2807,8 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
2753
2807
}
2754
2808
2755
2809
// Verifier for tosa.conv3d: delegates to the shared convolution checks —
// structural (verifyConvOp), accumulator/zero-point modes (verifyConvOpModes),
// and spec-level ERROR_IF predicates (verifyConvOpErrorIf).
LogicalResult Conv3DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
      verifyConvOpErrorIf(*this).failed())
    return failure();
  return success();
}
@@ -2863,7 +2918,8 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
2863
2918
}
2864
2919
2865
2920
// Verifier for tosa.depthwise_conv2d: delegates to the shared convolution
// checks — structural (verifyConvOp), accumulator/zero-point modes
// (verifyConvOpModes), and spec-level ERROR_IF predicates
// (verifyConvOpErrorIf).
LogicalResult DepthwiseConv2DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
      verifyConvOpErrorIf(*this).failed())
    return failure();
  return success();
}
0 commit comments