Skip to content

[X86] Combine store + vselect to masked_store #145176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

abhishek-kaushik22
Copy link
Contributor

Add a new combine to replace

(store ch (vselect cond truevec (load ch ptr offset)) ptr offset)

to

(mstore ch truevec ptr offset cond)

This saves a blend operation on targets that support conditional stores.

Add a new combine to replace
```
(store ch (vselect cond truevec (load ch ptr offset)) ptr offset)
```
to
```
(mstore ch truevec ptr offset cond)
```
@abhishek-kaushik22 abhishek-kaushik22 changed the title [X86] Combine store + vselect to masked_store` [X86] Combine store + vselect to masked_store Jun 21, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 21, 2025

@llvm/pr-subscribers-backend-x86

Author: Abhishek Kaushik (abhishek-kaushik22)

Changes

Add a new combine to replace

(store ch (vselect cond truevec (load ch ptr offset)) ptr offset)

to

(mstore ch truevec ptr offset cond)

This saves a blend operation on targets that support conditional stores.


Full diff: https://github.com/llvm/llvm-project/pull/145176.diff

2 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+78)
  • (added) llvm/test/CodeGen/X86/combine-storetomstore.ll (+276)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 33083c0eba695..7a8ec1b25de62 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -66,6 +66,7 @@
 #include <bitset>
 #include <cctype>
 #include <numeric>
+#include <queue>
 using namespace llvm;
 
 #define DEBUG_TYPE "x86-isel"
@@ -53403,6 +53404,80 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static SDValue foldToMaskedStore(StoreSDNode *Store, SelectionDAG &DAG,
+                                 const SDLoc &Dl,
+                                 const X86Subtarget &Subtarget) {
+  if (!Subtarget.hasAVX() && !Subtarget.hasAVX2() && !Subtarget.hasAVX512())
+    return SDValue();
+
+  if (!Store->isSimple())
+    return SDValue();
+
+  SDValue StoredVal = Store->getValue();
+  SDValue StorePtr = Store->getBasePtr();
+  SDValue StoreOffset = Store->getOffset();
+  EVT VT = StoredVal.getValueType();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  if (!TLI.isTypeLegal(VT) || !TLI.isOperationLegalOrCustom(ISD::MSTORE, VT))
+    return SDValue();
+
+  if (StoredVal.getOpcode() != ISD::VSELECT)
+    return SDValue();
+
+  SDValue Mask = StoredVal.getOperand(0);
+  SDValue TrueVec = StoredVal.getOperand(1);
+  SDValue FalseVec = StoredVal.getOperand(2);
+
+  LoadSDNode *Load = cast<LoadSDNode>(FalseVec.getNode());
+  if (!Load || !Load->isSimple())
+    return SDValue();
+
+  SDValue LoadPtr = Load->getBasePtr();
+  SDValue LoadOffset = Load->getOffset();
+
+  if (StorePtr != LoadPtr || StoreOffset != LoadOffset)
+    return SDValue();
+
+  auto IsSafeToFold = [](StoreSDNode *Store, LoadSDNode *Load) {
+    std::queue<SDValue> Worklist;
+
+    Worklist.push(Store->getChain());
+
+    while (!Worklist.empty()) {
+      SDValue Chain = Worklist.front();
+      Worklist.pop();
+
+      SDNode *Node = Chain.getNode();
+      if (!Node)
+        return false;
+
+      if (const auto *MemNode = dyn_cast<MemSDNode>(Node))
+        if (!MemNode->isSimple() || MemNode->writeMem())
+          return false;
+
+      if (Node == Load)
+        return true;
+
+      if (Node->getOpcode() == ISD::TokenFactor) {
+        for (unsigned i = 0; i < Node->getNumOperands(); ++i)
+          Worklist.push(Node->getOperand(i));
+      } else {
+        Worklist.push(Node->getOperand(0));
+      }
+    }
+
+    return false;
+  };
+
+  if (!IsSafeToFold(Store, Load))
+    return SDValue();
+
+  return DAG.getMaskedStore(Store->getChain(), Dl, TrueVec, StorePtr,
+                            StoreOffset, Mask, Store->getMemoryVT(),
+                            Store->getMemOperand(), Store->getAddressingMode());
+}
+
 static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
                             TargetLowering::DAGCombinerInfo &DCI,
                             const X86Subtarget &Subtarget) {
@@ -53728,6 +53803,9 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
                         St->getMemOperand()->getFlags());
   }
 
+  if (SDValue MaskedStore = foldToMaskedStore(St, DAG, dl, Subtarget))
+    return MaskedStore;
+
   return SDValue();
 }
 
diff --git a/llvm/test/CodeGen/X86/combine-storetomstore.ll b/llvm/test/CodeGen/X86/combine-storetomstore.ll
new file mode 100644
index 0000000000000..75d0dd85cafda
--- /dev/null
+++ b/llvm/test/CodeGen/X86/combine-storetomstore.ll
@@ -0,0 +1,276 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+avx     | FileCheck %s -check-prefix=AVX
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+avx2    | FileCheck %s -check-prefix=AVX2
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+avx512f | FileCheck %s -check-prefix=AVX512
+
+
+define void @test_masked_store_success(<8 x i32> %x, ptr %ptr, <8 x i1> %cmp) {
+; AVX-LABEL: test_masked_store_success:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmovzxwd {{.*#+}} xmm2 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero
+; AVX-NEXT:    vpslld $31, %xmm2, %xmm2
+; AVX-NEXT:    vpunpckhwd {{.*#+}} xmm1 = xmm1[4,4,5,5,6,6,7,7]
+; AVX-NEXT:    vpslld $31, %xmm1, %xmm1
+; AVX-NEXT:    vinsertf128 $1, %xmm1, %ymm2, %ymm1
+; AVX-NEXT:    vmaskmovps %ymm0, %ymm1, (%rdi)
+; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    retq
+;
+; AVX2-LABEL: test_masked_store_success:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmovzxwd {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero
+; AVX2-NEXT:    vpslld $31, %ymm1, %ymm1
+; AVX2-NEXT:    vpmaskmovd %ymm0, %ymm1, (%rdi)
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: test_masked_store_success:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    # kill: def $ymm0 killed $ymm0 def $zmm0
+; AVX512-NEXT:    vpmovsxwq %xmm1, %zmm1
+; AVX512-NEXT:    vpsllq $63, %zmm1, %zmm1
+; AVX512-NEXT:    vptestmq %zmm1, %zmm1, %k1
+; AVX512-NEXT:    vmovdqu32 %zmm0, (%rdi) {%k1}
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+  %load = load <8 x i32>, ptr %ptr, align 32
+  %sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
+  store <8 x i32> %sel, ptr %ptr, align 32
+  ret void
+}
+
+define void @test_masked_store_volatile_load(<8 x i32> %x, ptr %ptr, <8 x i1> %cmp) {
+; AVX-LABEL: test_masked_store_volatile_load:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmovzxwd {{.*#+}} xmm2 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero
+; AVX-NEXT:    vpslld $31, %xmm2, %xmm2
+; AVX-NEXT:    vpunpckhwd {{.*#+}} xmm1 = xmm1[4,4,5,5,6,6,7,7]
+; AVX-NEXT:    vpslld $31, %xmm1, %xmm1
+; AVX-NEXT:    vinsertf128 $1, %xmm1, %ymm2, %ymm1
+; AVX-NEXT:    vmovaps (%rdi), %ymm2
+; AVX-NEXT:    vblendvps %ymm1, %ymm0, %ymm2, %ymm0
+; AVX-NEXT:    vmovaps %ymm0, (%rdi)
+; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    retq
+;
+; AVX2-LABEL: test_masked_store_volatile_load:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmovzxwd {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero
+; AVX2-NEXT:    vpslld $31, %ymm1, %ymm1
+; AVX2-NEXT:    vmovaps (%rdi), %ymm2
+; AVX2-NEXT:    vblendvps %ymm1, %ymm0, %ymm2, %ymm0
+; AVX2-NEXT:    vmovaps %ymm0, (%rdi)
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: test_masked_store_volatile_load:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    # kill: def $ymm0 killed $ymm0 def $zmm0
+; AVX512-NEXT:    vpmovsxwq %xmm1, %zmm1
+; AVX512-NEXT:    vpsllq $63, %zmm1, %zmm1
+; AVX512-NEXT:    vptestmq %zmm1, %zmm1, %k1
+; AVX512-NEXT:    vmovdqa (%rdi), %ymm1
+; AVX512-NEXT:    vmovdqa32 %zmm0, %zmm1 {%k1}
+; AVX512-NEXT:    vmovdqa %ymm1, (%rdi)
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+  %load = load volatile <8 x i32>, ptr %ptr, align 32
+  %sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
+  store <8 x i32> %sel, ptr %ptr, align 32
+  ret void
+}
+
+define void @test_masked_store_volatile_store(<8 x i32> %x, ptr %ptr, <8 x i1> %cmp) {
+; AVX-LABEL: test_masked_store_volatile_store:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmovzxwd {{.*#+}} xmm2 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero
+; AVX-NEXT:    vpslld $31, %xmm2, %xmm2
+; AVX-NEXT:    vpunpckhwd {{.*#+}} xmm1 = xmm1[4,4,5,5,6,6,7,7]
+; AVX-NEXT:    vpslld $31, %xmm1, %xmm1
+; AVX-NEXT:    vinsertf128 $1, %xmm1, %ymm2, %ymm1
+; AVX-NEXT:    vmovaps (%rdi), %ymm2
+; AVX-NEXT:    vblendvps %ymm1, %ymm0, %ymm2, %ymm0
+; AVX-NEXT:    vmovaps %ymm0, (%rdi)
+; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    retq
+;
+; AVX2-LABEL: test_masked_store_volatile_store:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmovzxwd {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero
+; AVX2-NEXT:    vpslld $31, %ymm1, %ymm1
+; AVX2-NEXT:    vmovaps (%rdi), %ymm2
+; AVX2-NEXT:    vblendvps %ymm1, %ymm0, %ymm2, %ymm0
+; AVX2-NEXT:    vmovaps %ymm0, (%rdi)
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: test_masked_store_volatile_store:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    # kill: def $ymm0 killed $ymm0 def $zmm0
+; AVX512-NEXT:    vpmovsxwq %xmm1, %zmm1
+; AVX512-NEXT:    vpsllq $63, %zmm1, %zmm1
+; AVX512-NEXT:    vptestmq %zmm1, %zmm1, %k1
+; AVX512-NEXT:    vmovdqa (%rdi), %ymm1
+; AVX512-NEXT:    vmovdqa32 %zmm0, %zmm1 {%k1}
+; AVX512-NEXT:    vmovdqa %ymm1, (%rdi)
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+  %load = load <8 x i32>, ptr %ptr, align 32
+  %sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
+  store volatile <8 x i32> %sel, ptr %ptr, align 32
+  ret void
+}
+
+declare void @use_vec(<8 x i32>)
+
+define void @test_masked_store_intervening(<8 x i32> %x, ptr %ptr, <8 x i1> %cmp) {
+; AVX-LABEL: test_masked_store_intervening:
+; AVX:       # %bb.0:
+; AVX-NEXT:    pushq %rbx
+; AVX-NEXT:    .cfi_def_cfa_offset 16
+; AVX-NEXT:    subq $32, %rsp
+; AVX-NEXT:    .cfi_def_cfa_offset 48
+; AVX-NEXT:    .cfi_offset %rbx, -16
+; AVX-NEXT:    movq %rdi, %rbx
+; AVX-NEXT:    vpmovzxwd {{.*#+}} xmm2 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero
+; AVX-NEXT:    vpslld $31, %xmm2, %xmm2
+; AVX-NEXT:    vpunpckhwd {{.*#+}} xmm1 = xmm1[4,4,5,5,6,6,7,7]
+; AVX-NEXT:    vpslld $31, %xmm1, %xmm1
+; AVX-NEXT:    vinsertf128 $1, %xmm1, %ymm2, %ymm1
+; AVX-NEXT:    vmovaps (%rdi), %ymm2
+; AVX-NEXT:    vblendvps %ymm1, %ymm0, %ymm2, %ymm0
+; AVX-NEXT:    vmovups %ymm0, (%rsp) # 32-byte Spill
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
+; AVX-NEXT:    vmovaps %ymm0, (%rdi)
+; AVX-NEXT:    callq use_vec@PLT
+; AVX-NEXT:    vmovups (%rsp), %ymm0 # 32-byte Reload
+; AVX-NEXT:    vmovaps %ymm0, (%rbx)
+; AVX-NEXT:    addq $32, %rsp
+; AVX-NEXT:    .cfi_def_cfa_offset 16
+; AVX-NEXT:    popq %rbx
+; AVX-NEXT:    .cfi_def_cfa_offset 8
+; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    retq
+;
+; AVX2-LABEL: test_masked_store_intervening:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    pushq %rbx
+; AVX2-NEXT:    .cfi_def_cfa_offset 16
+; AVX2-NEXT:    subq $32, %rsp
+; AVX2-NEXT:    .cfi_def_cfa_offset 48
+; AVX2-NEXT:    .cfi_offset %rbx, -16
+; AVX2-NEXT:    movq %rdi, %rbx
+; AVX2-NEXT:    vpmovzxwd {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero
+; AVX2-NEXT:    vpslld $31, %ymm1, %ymm1
+; AVX2-NEXT:    vmovaps (%rdi), %ymm2
+; AVX2-NEXT:    vblendvps %ymm1, %ymm0, %ymm2, %ymm0
+; AVX2-NEXT:    vmovups %ymm0, (%rsp) # 32-byte Spill
+; AVX2-NEXT:    vxorps %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    vmovaps %ymm0, (%rdi)
+; AVX2-NEXT:    callq use_vec@PLT
+; AVX2-NEXT:    vmovups (%rsp), %ymm0 # 32-byte Reload
+; AVX2-NEXT:    vmovaps %ymm0, (%rbx)
+; AVX2-NEXT:    addq $32, %rsp
+; AVX2-NEXT:    .cfi_def_cfa_offset 16
+; AVX2-NEXT:    popq %rbx
+; AVX2-NEXT:    .cfi_def_cfa_offset 8
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: test_masked_store_intervening:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    pushq %rbx
+; AVX512-NEXT:    .cfi_def_cfa_offset 16
+; AVX512-NEXT:    subq $144, %rsp
+; AVX512-NEXT:    .cfi_def_cfa_offset 160
+; AVX512-NEXT:    .cfi_offset %rbx, -16
+; AVX512-NEXT:    movq %rdi, %rbx
+; AVX512-NEXT:    # kill: def $ymm0 killed $ymm0 def $zmm0
+; AVX512-NEXT:    vmovups %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill
+; AVX512-NEXT:    vpmovsxwq %xmm1, %zmm0
+; AVX512-NEXT:    vpsllq $63, %zmm0, %zmm0
+; AVX512-NEXT:    vptestmq %zmm0, %zmm0, %k1
+; AVX512-NEXT:    kmovw %k1, {{[-0-9]+}}(%r{{[sb]}}p) # 2-byte Spill
+; AVX512-NEXT:    vmovaps (%rdi), %ymm0
+; AVX512-NEXT:    vmovups %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill
+; AVX512-NEXT:    vxorps %xmm0, %xmm0, %xmm0
+; AVX512-NEXT:    vmovaps %ymm0, (%rdi)
+; AVX512-NEXT:    callq use_vec@PLT
+; AVX512-NEXT:    vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload
+; AVX512-NEXT:    vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload
+; AVX512-NEXT:    kmovw {{[-0-9]+}}(%r{{[sb]}}p), %k1 # 2-byte Reload
+; AVX512-NEXT:    vmovdqa32 %zmm0, %zmm1 {%k1}
+; AVX512-NEXT:    vmovdqa %ymm1, (%rbx)
+; AVX512-NEXT:    addq $144, %rsp
+; AVX512-NEXT:    .cfi_def_cfa_offset 16
+; AVX512-NEXT:    popq %rbx
+; AVX512-NEXT:    .cfi_def_cfa_offset 8
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+  %load = load <8 x i32>, ptr %ptr, align 32
+  store <8 x i32> zeroinitializer, ptr %ptr, align 32
+  %tmp = load <8 x i32>, ptr %ptr
+  call void @use_vec(<8 x i32> %tmp)
+  %sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
+  store <8 x i32> %sel, ptr %ptr, align 32
+  ret void
+}
+
+
+define void @test_masked_store_multiple(<8 x i32> %x, <8 x i32> %y, ptr %ptr1, ptr %ptr2, <8 x i1> %cmp, <8 x i1> %cmp2) {
+; AVX-LABEL: foo:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmovzxwd {{.*#+}} xmm4 = xmm2[0],zero,xmm2[1],zero,xmm2[2],zero,xmm2[3],zero
+; AVX-NEXT:    vpslld $31, %xmm4, %xmm4
+; AVX-NEXT:    vpunpckhwd {{.*#+}} xmm2 = xmm2[4,4,5,5,6,6,7,7]
+; AVX-NEXT:    vpslld $31, %xmm2, %xmm2
+; AVX-NEXT:    vinsertf128 $1, %xmm2, %ymm4, %ymm2
+; AVX-NEXT:    vpmovzxwd {{.*#+}} xmm4 = xmm3[0],zero,xmm3[1],zero,xmm3[2],zero,xmm3[3],zero
+; AVX-NEXT:    vpslld $31, %xmm4, %xmm4
+; AVX-NEXT:    vpunpckhwd {{.*#+}} xmm3 = xmm3[4,4,5,5,6,6,7,7]
+; AVX-NEXT:    vpslld $31, %xmm3, %xmm3
+; AVX-NEXT:    vinsertf128 $1, %xmm3, %ymm4, %ymm3
+; AVX-NEXT:    vmovaps (%rsi), %ymm4
+; AVX-NEXT:    vblendvps %ymm3, %ymm1, %ymm4, %ymm1
+; AVX-NEXT:    vmaskmovps %ymm0, %ymm2, (%rdi)
+; AVX-NEXT:    vmovaps %ymm1, (%rsi)
+; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    retq
+;
+; AVX2-LABEL: foo:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmovzxwd {{.*#+}} ymm2 = xmm2[0],zero,xmm2[1],zero,xmm2[2],zero,xmm2[3],zero,xmm2[4],zero,xmm2[5],zero,xmm2[6],zero,xmm2[7],zero
+; AVX2-NEXT:    vpslld $31, %ymm2, %ymm2
+; AVX2-NEXT:    vpmovzxwd {{.*#+}} ymm3 = xmm3[0],zero,xmm3[1],zero,xmm3[2],zero,xmm3[3],zero,xmm3[4],zero,xmm3[5],zero,xmm3[6],zero,xmm3[7],zero
+; AVX2-NEXT:    vpslld $31, %ymm3, %ymm3
+; AVX2-NEXT:    vmovaps (%rsi), %ymm4
+; AVX2-NEXT:    vblendvps %ymm3, %ymm1, %ymm4, %ymm1
+; AVX2-NEXT:    vpmaskmovd %ymm0, %ymm2, (%rdi)
+; AVX2-NEXT:    vmovaps %ymm1, (%rsi)
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: foo:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    # kill: def $ymm1 killed $ymm1 def $zmm1
+; AVX512-NEXT:    # kill: def $ymm0 killed $ymm0 def $zmm0
+; AVX512-NEXT:    vpmovsxwq %xmm2, %zmm2
+; AVX512-NEXT:    vpsllq $63, %zmm2, %zmm2
+; AVX512-NEXT:    vptestmq %zmm2, %zmm2, %k1
+; AVX512-NEXT:    vpmovsxwq %xmm3, %zmm2
+; AVX512-NEXT:    vpsllq $63, %zmm2, %zmm2
+; AVX512-NEXT:    vptestmq %zmm2, %zmm2, %k2
+; AVX512-NEXT:    vmovdqa (%rsi), %ymm2
+; AVX512-NEXT:    vmovdqa32 %zmm1, %zmm2 {%k2}
+; AVX512-NEXT:    vmovdqu32 %zmm0, (%rdi) {%k1}
+; AVX512-NEXT:    vmovdqa %ymm2, (%rsi)
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+  %load = load <8 x i32>, ptr %ptr1, align 32
+  %load2 = load <8 x i32>, ptr %ptr2, align 32
+  %sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
+  %sel2 = select <8 x i1> %cmp2, <8 x i32> %y, <8 x i32> %load2
+  store <8 x i32> %sel, ptr %ptr1, align 32
+  store <8 x i32> %sel2, ptr %ptr2, align 32
+  ret void
+}

@phoebewang
Copy link
Contributor

Can we use a pattern match instead?

@abhishek-kaushik22
Copy link
Contributor Author

Can we use a pattern match instead?

There was no m_Load to match the Load, so I've added one.

inline TernaryOpc_match<T0_P, T1_P, T2_P>
m_Load(const T0_P &Ch, const T1_P &Ptr, const T2_P &Offset) {
return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::LOAD, Ch, Ptr, Offset);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please can you pull this out into its own PR and add unittests to SelectionDAGPatternMatchTest.cpp

@@ -53403,6 +53404,76 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

static SDValue foldToMaskedStore(StoreSDNode *Store, SelectionDAG &DAG,
const SDLoc &Dl,
const X86Subtarget &Subtarget) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we doing this in the DAG and not earlier on in SLP/VectorCombine/InstCombine ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried in instcombine first #144298

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK - but there doesn't seem to be any reason for this to be in X86ISelLowering and not DAGCombiner::visitSTORE - the Subtarget checks looks to be unnecessary based on the ISD::MSTORE legality checks below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ISD::MSTORE may be legal or custom on many targets that do not have a conditional store instruction, for example in case of sse2 it's lowered as checking the mask bits individually and storing each scalar value.
I might be missing some other way of checking for it, please suggest if that's the case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The legality check helps simplify the mask that is passed to mstore, before type legalization it's usually a vXi1 which may not be legal

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setOperationAction(ISD::MSTORE is not set for anything prior to AVX - are you referring to tests with @llvm.masked.store intrinsics that have been legalized?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I was thinking that this would translate to a mstore in initial DAG and then custom lowered

define void @foo(<8 x i32> %x, ptr %ptr1, <8 x i1> %cmp) {
  tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %x, ptr %ptr1, i32 32, <8 x i1> %cmp)
  ret void
}

Makes sense to put this in DAGCombiner, but I don't have any way to test on non-x86 platforms

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding ARM MVE or AArch64 test coverage should be very straightforward.

@phoebewang
Copy link
Contributor

Can we use a pattern match instead?

There was no m_Load to match the Load, so I've added one.

I mean add

def: Pat<(st_frag (_.VT (vselect_mask _.KRCWM:$mask, (_.VT _.RC:$src), (_.VT (ld_frag addr:$dst)))), addr:$dst), (!cast<Instruction>(Name#_.ZSuffix#mrk) addr:$ptr, _.KRCWM:$mask, _.RC:$src)>;
def: Pat<(st_frag (_.VT (vselect_mask _.KRCWM:$mask, (_.VT _.RC:$src), _.ImmAllZerosV)), addr:$dst), (!cast<Instruction>(Name#_.ZSuffix#mrkz) addr:$ptr, _.KRCWM:$mask, _.RC:$src)>;

to avx512_store.

@abhishek-kaushik22
Copy link
Contributor Author

Can we use a pattern match instead?

There was no m_Load to match the Load, so I've added one.

I mean add

def: Pat<(st_frag (_.VT (vselect_mask _.KRCWM:$mask, (_.VT _.RC:$src), (_.VT (ld_frag addr:$dst)))), addr:$dst), (!cast<Instruction>(Name#_.ZSuffix#mrk) addr:$ptr, _.KRCWM:$mask, _.RC:$src)>;
def: Pat<(st_frag (_.VT (vselect_mask _.KRCWM:$mask, (_.VT _.RC:$src), _.ImmAllZerosV)), addr:$dst), (!cast<Instruction>(Name#_.ZSuffix#mrkz) addr:$ptr, _.KRCWM:$mask, _.RC:$src)>;

to avx512_store.

Is that safe to do? There might be other stores in between the store and load, and we don't have any alias information

const X86Subtarget &Subtarget) {
using namespace llvm::SDPatternMatch;

if (!Subtarget.hasAVX() && !Subtarget.hasAVX2() && !Subtarget.hasAVX512())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check hasAVX only.

if (!MemNode->isSimple() || MemNode->writeMem())
return false;

if (Node->getOpcode() == ISD::TokenFactor) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a test for it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

define void @test_masked_store_multiple(<8 x i32> %x, <8 x i32> %y, ptr %ptr1, ptr %ptr2, <8 x i1> %cmp, <8 x i1> %cmp2) {
  %load = load <8 x i32>, ptr %ptr1, align 32
  %load2 = load <8 x i32>, ptr %ptr2, align 32
  %sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
  %sel2 = select <8 x i1> %cmp2, <8 x i32> %y, <8 x i32> %load2
  store <8 x i32> %sel, ptr %ptr1, align 32
  store <8 x i32> %sel2, ptr %ptr2, align 32
  ret void
}

This test generates a DAG which has a TokenFactor node. https://godbolt.org/z/zdTce1xYE

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants