@@ -32106,6 +32106,15 @@ bool GenTree::CanDivOrModPossiblyOverflow(Compiler* comp) const
3210632106    return true;
3210732107}
3210832108
32109+ //------------------------------------------------------------------------
32110+ // gtFoldExprHWIntrinsic: Attempt to fold a HWIntrinsic
32111+ //
32112+ // Arguments:
32113+ //    tree - HWIntrinsic to fold
32114+ //
32115+ // Return Value:
32116+ //    folded expression if it could be folded, else the original tree
32117+ //
3210932118#if defined(FEATURE_HW_INTRINSICS)
3211032119GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3211132120{
@@ -32249,7 +32258,8 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3224932258    // We shouldn't find AND_NOT nodes since they should only be produced in lowering
3225032259    assert(oper != GT_AND_NOT);
3225132260
32252- #if defined(FEATURE_MASKED_HW_INTRINSICS) && defined(TARGET_XARCH)
32261+ #ifdef FEATURE_MASKED_HW_INTRINSICS
32262+ #ifdef TARGET_XARCH
3225332263    if (GenTreeHWIntrinsic::OperIsBitwiseHWIntrinsic(oper))
3225432264    {
3225532265        // Comparisons that produce masks lead to more verbose trees than
@@ -32367,7 +32377,75 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3236732377            }
3236832378        }
3236932379    }
32370- #endif // FEATURE_MASKED_HW_INTRINSICS && TARGET_XARCH
32380+ #elif defined(TARGET_ARM64)
32381+     // Check if the tree can be folded into a mask variant
32382+     if (HWIntrinsicInfo::HasAllMaskVariant(tree->GetHWIntrinsicId()))
32383+     {
32384+         NamedIntrinsic maskVariant = HWIntrinsicInfo::GetMaskVariant(tree->GetHWIntrinsicId());
32385+ 
32386+         assert(opCount == (size_t)HWIntrinsicInfo::lookupNumArgs(maskVariant));
32387+ 
32388+         // Folding requires every vector operand to be a ConvertMaskToVector node
32389+         bool canFold = true;
32390+         if (ni == NI_Sve_ConditionalSelect)
32391+         {
32392+             assert(varTypeIsMask(op1));
32393+             canFold = (op2->OperIsConvertMaskToVector() && op3->OperIsConvertMaskToVector());
32394+         }
32395+         else
32396+         {
32397+             for (size_t i = 1; i <= opCount && canFold; i++)
32398+             {
32399+                 canFold &= tree->Op(i)->OperIsConvertMaskToVector();
32400+             }
32401+         }
32402+ 
32403+         if (canFold)
32404+         {
32405+             // Convert all the operands to masks
32406+             for (size_t i = 1; i <= opCount; i++)
32407+             {
32408+                 if (tree->Op(i)->OperIsConvertMaskToVector())
32409+                 {
32410+                     // Unwrap the ConvertMaskToVector node and use its first operand directly.
32411+                     tree->Op(i) = tree->Op(i)->AsHWIntrinsic()->Op(1);
32412+                 }
32413+                 else if (tree->Op(i)->IsVectorZero())
32414+                 {
32415+                     // Replace the vector of zeroes with a mask of zeroes.
32416+                     tree->Op(i) = gtNewSimdFalseMaskByteNode();
32417+                     tree->Op(i)->SetMorphed(this);
32418+                 }
32419+                 assert(varTypeIsMask(tree->Op(i)));
32420+             }
32421+ 
32422+             // Switch to the mask variant
32423+             switch (opCount)
32424+             {
32425+                 case 1:
32426+                     tree->ResetHWIntrinsicId(maskVariant, tree->Op(1));
32427+                     break;
32428+                 case 2:
32429+                     tree->ResetHWIntrinsicId(maskVariant, tree->Op(1), tree->Op(2));
32430+                     break;
32431+                 case 3:
32432+                     tree->ResetHWIntrinsicId(maskVariant, this, tree->Op(1), tree->Op(2), tree->Op(3));
32433+                     break;
32434+                 default:
32435+                     unreached();
32436+             }
32437+ 
32438+             tree->gtType = TYP_MASK;
32439+             tree->SetMorphed(this);
32440+             tree = gtNewSimdCvtMaskToVectorNode(retType, tree, simdBaseJitType, simdSize)->AsHWIntrinsic();
32441+             tree->SetMorphed(this);
32442+             op1 = tree->Op(1);
32443+             op2 = nullptr;
32444+             op3 = nullptr;
32445+         }
32446+     }
32447+ #endif // TARGET_ARM64
32448+ #endif // FEATURE_MASKED_HW_INTRINSICS
3237132449
3237232450    GenTree* cnsNode   = nullptr;
3237332451    GenTree* otherNode = nullptr;
@@ -33754,7 +33832,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3375433832                    // op2 = op2 & op1
3375533833                    op2->AsVecCon()->EvaluateBinaryInPlace(GT_AND, false, simdBaseType, op1->AsVecCon());
3375633834
33757-                     // op3 = op2  & ~op1
33835+                     // op3 = op3 & ~op1
3375833836                    op3->AsVecCon()->EvaluateBinaryInPlace(GT_AND_NOT, false, simdBaseType, op1->AsVecCon());
3375933837
3376033838                    // op2 = op2 | op3
@@ -33767,8 +33845,8 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3376733845
3376833846#if defined(TARGET_ARM64)
3376933847            case NI_Sve_ConditionalSelect:
33848+             case NI_Sve_ConditionalSelect_Predicates:
3377033849            {
33771-                 assert(!varTypeIsMask(retType));
3377233850                assert(varTypeIsMask(op1));
3377333851
3377433852                if (cnsNode != op1)
@@ -33797,10 +33875,11 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3379733875
3379833876                if (op2->IsCnsVec() && op3->IsCnsVec())
3379933877                {
33878+                     assert(ni == NI_Sve_ConditionalSelect);
3380033879                    assert(op2->gtType == TYP_SIMD16);
3380133880                    assert(op3->gtType == TYP_SIMD16);
3380233881
33803-                     simd16_t op1SimdVal;
33882+                     simd16_t op1SimdVal = {} ;
3380433883                    EvaluateSimdCvtMaskToVector<simd16_t>(simdBaseType, &op1SimdVal, op1->AsMskCon()->gtSimdMaskVal);
3380533884
3380633885                    // op2 = op2 & op1
@@ -33809,7 +33888,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3380933888                                                 op1SimdVal);
3381033889                    op2->AsVecCon()->gtSimd16Val = result;
3381133890
33812-                     // op3 = op2  & ~op1
33891+                     // op3 = op3 & ~op1
3381333892                    result = {};
3381433893                    EvaluateBinarySimd<simd16_t>(GT_AND_NOT, false, simdBaseType, &result, op3->AsVecCon()->gtSimd16Val,
3381533894                                                 op1SimdVal);
@@ -33820,6 +33899,30 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3382033899
3382133900                    resultNode = op2;
3382233901                }
33902+                 else if (op2->IsCnsMsk() && op3->IsCnsMsk())
33903+                 {
33904+                     assert(ni == NI_Sve_ConditionalSelect_Predicates);
33905+ 
33906+                     // op2 = op2 & op1
33907+                     simdmask_t result = {};
33908+                     EvaluateBinaryMask<simd16_t>(GT_AND, false, simdBaseType, &result, op2->AsMskCon()->gtSimdMaskVal,
33909+                                                  op1->AsMskCon()->gtSimdMaskVal);
33910+                     op2->AsMskCon()->gtSimdMaskVal = result;
33911+ 
33912+                     // op3 = op3 & ~op1
33913+                     result = {};
33914+                     EvaluateBinaryMask<simd16_t>(GT_AND_NOT, false, simdBaseType, &result,
33915+                                                  op3->AsMskCon()->gtSimdMaskVal, op1->AsMskCon()->gtSimdMaskVal);
33916+                     op3->AsMskCon()->gtSimdMaskVal = result;
33917+ 
33918+                     // op2 = op2 | op3
33919+                     result = {};
33920+                     EvaluateBinaryMask<simd16_t>(GT_OR, false, simdBaseType, &result, op2->AsMskCon()->gtSimdMaskVal,
33921+                                                  op3->AsMskCon()->gtSimdMaskVal);
33922+                     op2->AsMskCon()->gtSimdMaskVal = result;
33923+ 
33924+                     resultNode = op2;
33925+                 }
3382333926                break;
3382433927            }
3382533928#endif // TARGET_ARM64
0 commit comments