@@ -91,34 +91,40 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
9191}
9292
9393// / Expands tanh op into
94- // / 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
95- // / 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
94+ // / 1-exp^{-2x} / 1+exp^{-2x}
95+ // / To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
96+ // / We compute a "signs" value which is -1 if input is negative and +1 if input
97+ // / is positive. Then multiply the input by this value, guaranteeing that the
98+ // / result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
99+ // / 1]. Expand the computation on the input `x * sign(x)`, then multiply the
100+ // / result by `sign(x)` to retain sign of the real result.
96101static LogicalResult convertTanhOp (math::TanhOp op, PatternRewriter &rewriter) {
97102 auto floatType = op.getOperand ().getType ();
98103 Location loc = op.getLoc ();
104+ Value zero = createFloatConst (loc, floatType, 0.0 , rewriter);
99105 Value one = createFloatConst (loc, floatType, 1.0 , rewriter);
100- Value two = createFloatConst (loc, floatType, 2.0 , rewriter);
101- Value doubledX = rewriter.create <arith::MulFOp>(loc, op.getOperand (), two);
106+ Value negTwo = createFloatConst (loc, floatType, -2.0 , rewriter);
107+
108+ // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
109+ Value sign = rewriter.create <arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
110+ op.getOperand (), zero);
111+ sign = rewriter.create <arith::SIToFPOp>(loc, floatType, sign);
112+ sign = rewriter.create <arith::MulFOp>(loc, sign, negTwo);
113+ sign = rewriter.create <arith::AddFOp>(loc, sign, one);
102114
103- // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
104- Value negDoubledX = rewriter.create <arith::NegFOp>(loc, doubledX);
115+ // Normalize input to positive value: y = sign(x) * x
116+ Value positiveX = rewriter.create <arith::MulFOp>(loc, sign, op.getOperand ());
117+
118+ // Decompose on normalized input
119+ Value negDoubledX = rewriter.create <arith::MulFOp>(loc, negTwo, positiveX);
105120 Value exp2x = rewriter.create <math::ExpOp>(loc, negDoubledX);
106121 Value dividend = rewriter.create <arith::SubFOp>(loc, one, exp2x);
107122 Value divisor = rewriter.create <arith::AddFOp>(loc, one, exp2x);
108123 Value positiveRes = rewriter.create <arith::DivFOp>(loc, dividend, divisor);
109124
110- // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
111- exp2x = rewriter.create <math::ExpOp>(loc, doubledX);
112- dividend = rewriter.create <arith::SubFOp>(loc, exp2x, one);
113- divisor = rewriter.create <arith::AddFOp>(loc, exp2x, one);
114- Value negativeRes = rewriter.create <arith::DivFOp>(loc, dividend, divisor);
125+ // Multiply result by sign(x) to retain signs from negative inputs
126+ rewriter.replaceOpWithNewOp <arith::MulFOp>(op, sign, positiveRes);
115127
116- // tanh(x) = x >= 0 ? positiveRes : negativeRes
117- Value zero = createFloatConst (loc, floatType, 0.0 , rewriter);
118- Value cmpRes = rewriter.create <arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
119- op.getOperand (), zero);
120- rewriter.replaceOpWithNewOp <arith::SelectOp>(op, cmpRes, positiveRes,
121- negativeRes);
122128 return success ();
123129}
124130
0 commit comments