@@ -82,69 +82,73 @@ OpFoldResult MajorityInverterOp::fold(FoldAdaptor adaptor) {
8282
8383 if (getNumOperands () != 3 )
8484 return {};
85+
86+ // Return if the idx-th operand is a constant (inverted if necessary),
87+ // otherwise return std::nullopt.
88+ auto getConstant = [&](unsigned index) -> std::optional<llvm::APInt> {
89+ APInt value;
90+ if (mlir::matchPattern (getInputs ()[index], mlir::m_ConstantInt (&value)))
91+ return isInverted (index) ? ~value : value;
92+ return std::nullopt ;
93+ };
8594 if (nonConstantValues.size () == 1 ) {
8695 auto k = nonConstantValues[0 ]; // for 3 operands
8796 auto i = (k + 1 ) % 3 ;
8897 auto j = (k + 2 ) % 3 ;
89- auto c1 = adaptor. getInputs ()[i] ;
90- auto c2 = adaptor. getInputs ()[j] ;
98+ auto c1 = getConstant (i) ;
99+ auto c2 = getConstant (j) ;
91100 // x 0 0 -> 0
92101 // x 1 1 -> 1
93102 // x 0 ~0 -> x
94103 // x 1 ~1 -> x
95104 // x ~1 ~1 -> ~1 -> 0
96- // ~x 0 0 -> ~x ?
105+ // ~x 0 0 -> ~x no fold
106+ // ~x 0 ~1 -> 0
97107 if (c1 == c2) {
98108 if (isInverted (i) != isInverted (j)) {
99- if (! isInverted (k ))
100- return getOperand (k );
109+ if (isInverted (i ))
110+ return getOperand (j );
101111 else
102- return {}; // ~x ? Invert the Operand can be handled by
103- // canoncialisation?
104- } else if (isInverted (i)) {
105- auto value = cast<IntegerAttr>(c1).getValue ();
112+ return getOperand (i);
113+ }
114+ if (isInverted (i)) {
115+ // return the inverted value
116+ auto value = cast<IntegerAttr>(adaptor.getInputs ()[i]).getValue ();
106117 value = ~value;
107118 return IntegerAttr::get (
108119 IntegerType::get (getContext (), value.getBitWidth ()), value);
109120 } else
110121 return getOperand (i);
111- } else if (isInverted (i) != isInverted (j)) {
112- // ~x 0 ~1 -> 0
113- // could be bug for multi bit value
114- // fix multi bit value
115- auto value = cast<IntegerAttr>(c1).getValue ();
116- auto width = value.getBitWidth ();
117- if (width != 1 )
122+ } else {
123+ if (isInverted (k))
118124 return {};
119- if (isInverted (i))
120- return getOperand (j);
121- return getOperand (i);
122- }
123- } else if (nonConstantValues.size () == 2 ) {
124- // x x 1 -> x
125- // x ~x 1 -> 1
126- // ~x ~x 1 -> ~x
127- auto k = 3 - (nonConstantValues[0 ] + nonConstantValues[1 ]);
128- auto i = nonConstantValues[0 ];
129- auto j = nonConstantValues[1 ];
130- auto c1 = adaptor.getInputs ()[i];
131- auto c2 = adaptor.getInputs ()[j];
132- if (c1 == c2) {
133- if (isInverted (i) != isInverted (j)) {
134- if (!isInverted (k))
135- return getOperand (k);
136- auto value = cast<IntegerAttr>(adaptor.getInputs ()[k]).getValue ();
137- value = ~value;
138- return IntegerAttr::get (
139- IntegerType::get (getContext (), value.getBitWidth ()), value);
140- } else {
141- if (isInverted (i))
142- return {}; // how to return ~x?
143- else
144- return getOperand (k);
145- }
125+ else
126+ return getOperand (k);
146127 }
147128 }
129+ // else if (nonConstantValues.size() == 2) {
130+ // // x x 1 -> x
131+ // // x ~x 1 -> 1
132+ // // ~x ~x 1 -> ~x
133+ // auto k = 3 - (nonConstantValues[0] + nonConstantValues[1]);
134+ // auto i = nonConstantValues[0];
135+ // auto j = nonConstantValues[1];
136+ // auto c1 = adaptor.getInputs()[i];
137+ // auto c2 = adaptor.getInputs()[j];
138+ // if (c1 == c2) {
139+ // if (isInverted(i) != isInverted(j)) {
140+ // if (!isInverted(k))
141+ // return getOperand(k);
142+ // auto value = cast<IntegerAttr>(adaptor.getInputs()[k]).getValue();
143+ // value = ~value;
144+ // return IntegerAttr::get(
145+ // IntegerType::get(getContext(), value.getBitWidth()), value);
146+ // } else {
147+ // if(isInverted(i))return {};
148+ // else return getOperand(i);
149+ // }
150+ // }
151+ // }
148152 return {};
149153}
150154
@@ -161,15 +165,6 @@ LogicalResult MajorityInverterOp::canonicalize(MajorityInverterOp op,
161165 if (op.getNumOperands () != 3 )
162166 return failure ();
163167
164- // Return if the idx-th operand is a constant (inverted if necessary),
165- // otherwise return std::nullopt.
166- auto getConstant = [&](unsigned index) -> std::optional<llvm::APInt> {
167- APInt value;
168- if (mlir::matchPattern (op.getInputs ()[index], mlir::m_ConstantInt (&value)))
169- return op.isInverted (index) ? ~value : value;
170- return std::nullopt ;
171- };
172-
173168 // Replace the op with the idx-th operand (inverted if necessary).
174169 auto replaceWithIndex = [&](int index) {
175170 bool inverted = op.isInverted (index);
@@ -197,15 +192,6 @@ LogicalResult MajorityInverterOp::canonicalize(MajorityInverterOp op,
197192 rewriter.replaceOp (op, op.getOperand (i));
198193 return success ();
199194 }
200-
201- // If i and j are constant.
202- if (auto c1 = getConstant (i)) {
203- if (auto c2 = getConstant (j)) {
204- // If constants are complementary, we can fold.
205- if (*c1 == ~*c2)
206- return replaceWithIndex (k);
207- }
208- }
209195 }
210196 }
211197 return failure ();
0 commit comments