@@ -28,12 +28,14 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
2828 void SetUp () override {
2929 ::test::registerTestDialect (registry);
3030 context = std::make_unique<MLIRContext>(registry);
31+ builder = std::make_unique<OpBuilder>(context.get ());
3132 }
3233
3334 void testReplaceAllSymbolUses (ReplaceFnType replaceFn) {
3435 // Set up IR and find func ops.
3536 OwningOpRef<ModuleOp> module =
3637 parseSourceString<ModuleOp>(kInput , context.get ());
38+ ASSERT_TRUE (module );
3739 SymbolTable symbolTable (module .get ());
3840 auto opIterator = module ->getBody (0 )->getOperations ().begin ();
3941 auto fooOp = cast<FunctionOpInterface>(opIterator++);
@@ -46,7 +48,7 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
4648 ASSERT_TRUE (succeeded (res));
4749 ASSERT_TRUE (succeeded (verify (module .get ())));
4850
49- // Check that it got renamed.
51+ // Check that callee of the call op got renamed.
5052 bool calleeFound = false ;
5153 fooOp->walk ([&](CallOpInterface callOp) {
5254 StringAttr callee = callOp.getCallableForCallee ()
@@ -56,13 +58,19 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
5658 calleeFound = true ;
5759 });
5860 EXPECT_TRUE (calleeFound);
61+
62+ // Check that module attribute did *not* get renamed.
63+ auto moduleAttr = (*module )->getAttrOfType <FlatSymbolRefAttr>(" test.attr" );
64+ ASSERT_TRUE (moduleAttr);
65+ EXPECT_EQ (moduleAttr.getValue (), StringRef (" bar" ));
5966 }
6067
6168 std::unique_ptr<MLIRContext> context;
69+ std::unique_ptr<OpBuilder> builder;
6270
6371private:
6472 constexpr static llvm::StringLiteral kInput = R"MLIR(
65- module {
73+ module attributes { test.attr = @bar } {
6674 test.conversion_func_op private @foo() {
6775 "test.conversion_call_op"() { callee=@bar } : () -> ()
6876 "test.return"() : () -> ()
@@ -81,7 +89,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
8189 testReplaceAllSymbolUses ([&](auto symbolTable, auto module , auto fooOp,
8290 auto barOp) -> LogicalResult {
8391 return symbolTable.replaceAllSymbolUses (
84- barOp, StringAttr::get (context. get (), " baz" ), module );
92+ barOp, builder-> getStringAttr ( " baz" ), module );
8593 });
8694}
8795
@@ -90,8 +98,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
9098 testReplaceAllSymbolUses ([&](auto symbolTable, auto module , auto fooOp,
9199 auto barOp) -> LogicalResult {
92100 return symbolTable.replaceAllSymbolUses (
93- StringAttr::get (context.get (), " bar" ),
94- StringAttr::get (context.get (), " baz" ), module );
101+ builder->getStringAttr (" bar" ), builder->getStringAttr (" baz" ), module );
95102 });
96103}
97104
@@ -100,17 +107,17 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
100107 testReplaceAllSymbolUses ([&](auto symbolTable, auto module , auto fooOp,
101108 auto barOp) -> LogicalResult {
102109 return symbolTable.replaceAllSymbolUses (
103- barOp, StringAttr::get (context. get (), " baz" ), &module ->getRegion (0 ));
110+ barOp, builder-> getStringAttr ( " baz" ), &module ->getRegion (0 ));
104111 });
105112}
106113
107114TEST_F (ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
108115 // Symbol as `StringAttr`, rename within module body.
109116 testReplaceAllSymbolUses ([&](auto symbolTable, auto module , auto fooOp,
110117 auto barOp) -> LogicalResult {
111- return symbolTable.replaceAllSymbolUses (
112- StringAttr::get (context. get (), " bar " ),
113- StringAttr::get (context. get (), " baz " ), &module ->getRegion (0 ));
118+ return symbolTable.replaceAllSymbolUses (builder-> getStringAttr ( " bar " ),
119+ builder-> getStringAttr ( " baz " ),
120+ &module ->getRegion (0 ));
114121 });
115122}
116123
@@ -119,7 +126,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
119126 testReplaceAllSymbolUses ([&](auto symbolTable, auto module , auto fooOp,
120127 auto barOp) -> LogicalResult {
121128 return symbolTable.replaceAllSymbolUses (
122- barOp, StringAttr::get (context. get (), " baz" ), fooOp);
129+ barOp, builder-> getStringAttr ( " baz" ), fooOp);
123130 });
124131}
125132
@@ -128,8 +135,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
128135 testReplaceAllSymbolUses ([&](auto symbolTable, auto module , auto fooOp,
129136 auto barOp) -> LogicalResult {
130137 return symbolTable.replaceAllSymbolUses (
131- StringAttr::get (context.get (), " bar" ),
132- StringAttr::get (context.get (), " baz" ), fooOp);
138+ builder->getStringAttr (" bar" ), builder->getStringAttr (" baz" ), fooOp);
133139 });
134140}
135141
0 commit comments