Skip to content

Commit 8c4a07c

Browse files
committed
[DAGCombiner] Fold fold (fp_to_bf16 (bf16_to_fp op)) -> op
1 parent aaff3fb commit 8c4a07c

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,7 @@ namespace {
510510
SDValue visitMSCATTER(SDNode *N);
511511
SDValue visitFP_TO_FP16(SDNode *N);
512512
SDValue visitFP16_TO_FP(SDNode *N);
513+
SDValue visitFP_TO_BF16(SDNode *N);
513514
SDValue visitVECREDUCE(SDNode *N);
514515
SDValue visitVPOp(SDNode *N);
515516

@@ -1746,6 +1747,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
17461747
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
17471748
case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
17481749
case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
1750+
case ISD::FP_TO_BF16: return visitFP_TO_BF16(N);
17491751
case ISD::FREEZE: return visitFREEZE(N);
17501752
case ISD::VECREDUCE_FADD:
17511753
case ISD::VECREDUCE_FMUL:
@@ -23072,6 +23074,16 @@ SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
2307223074
return SDValue();
2307323075
}
2307423076

23077+
SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
23078+
SDValue N0 = N->getOperand(0);
23079+
23080+
// fold (fp_to_bf16 (bf16_to_fp op)) -> op
23081+
if (N0->getOpcode() == ISD::BF16_TO_FP)
23082+
return N0->getOperand(0);
23083+
23084+
return SDValue();
23085+
}
23086+
2307523087
SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
2307623088
SDValue N0 = N->getOperand(0);
2307723089
EVT VT = N0.getValueType();

llvm/test/CodeGen/X86/bfloat.ll

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,16 @@ define void @store_constant(ptr %pc) {
100100
store bfloat 1.0, ptr %pc
101101
ret void
102102
}
103+
104+
define void @fold_ext_trunc(ptr %pa, ptr %pc) {
105+
; CHECK-LABEL: fold_ext_trunc:
106+
; CHECK: # %bb.0:
107+
; CHECK-NEXT: movzwl (%rdi), %eax
108+
; CHECK-NEXT: movw %ax, (%rsi)
109+
; CHECK-NEXT: retq
110+
%a = load bfloat, ptr %pa
111+
%ext = fpext bfloat %a to float
112+
%trunc = fptrunc float %ext to bfloat
113+
store bfloat %trunc, ptr %pc
114+
ret void
115+
}

0 commit comments

Comments
 (0)