Skip to content

Commit 20123ad

Browse files
authored
Revamp arr_slice pass by using latest replacer APIs in asdl_cpp.py (#1530)
1 parent 5e9e730 commit 20123ad

File tree

2 files changed

+150
-193
lines changed

2 files changed

+150
-193
lines changed

src/libasr/asdl_cpp.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,14 @@ def visitModule(self, mod):
422422
self.emit(" Struct& self() { return static_cast<Struct&>(*this); }")
423423
self.emit("public:")
424424
self.emit(" ASR::expr_t** current_expr;")
425+
self.emit(" SymbolTable* current_scope;")
425426
self.emit("")
426427
self.emit(" void call_replacer() {}")
428+
self.emit(" void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) {")
429+
self.emit(" for (size_t i = 0; i < n_body; i++) {", 1)
430+
self.emit(" self().visit_stmt(*m_body[i]);", 1)
431+
self.emit(" }", 1)
432+
self.emit(" }")
427433
super(CallReplacerOnExpressionsVisitor, self).visitModule(mod)
428434
self.emit("};")
429435

@@ -440,14 +446,34 @@ def visitConstructor(self, cons, _):
440446

441447
def make_visitor(self, name, fields):
442448
self.emit("void visit_%s(const %s_t &x) {" % (name, name), 1)
449+
is_symtab_present = False
450+
is_stmt_present = False
451+
symtab_field_name = ""
452+
for field in fields:
453+
if field.type == "stmt":
454+
is_stmt_present = True
455+
if field.type == "symbol_table":
456+
is_symtab_present = True
457+
symtab_field_name = field.name
458+
if is_stmt_present and is_symtab_present:
459+
break
460+
if is_stmt_present and name not in ("Assignment", "ForAllSingle"):
461+
self.emit(" %s_t& xx = const_cast<%s_t&>(x);" % (name, name), 1)
443462
self.used = False
444-
have_body = False
463+
464+
if is_symtab_present:
465+
self.emit("SymbolTable* current_scope_copy = current_scope;", 2)
466+
self.emit("current_scope = x.m_%s;" % symtab_field_name, 2)
467+
445468
for field in fields:
446469
self.visitField(field)
447470
if not self.used:
448471
# Note: a better solution would be to change `&x` to `& /* x */`
449472
# above, but we would need to change emit to return a string.
450473
self.emit("if ((bool&)x) { } // Suppress unused warning", 2)
474+
475+
if is_symtab_present:
476+
self.emit("current_scope = current_scope_copy;", 2)
451477
self.emit("}", 1)
452478

453479
def insert_call_replacer_code(self, name, level, index=""):
@@ -462,6 +488,9 @@ def visitField(self, field):
462488
field.type not in self.data.simple_types):
463489
level = 2
464490
if field.seq:
491+
if field.type == "stmt":
492+
self.emit("self().transform_stmts(xx.m_%s, xx.n_%s);" % (field.name, field.name), level)
493+
return
465494
self.used = True
466495
self.emit("for (size_t i=0; i<x.n_%s; i++) {" % field.name, level)
467496
if field.type in products:

0 commit comments

Comments
 (0)