Skip to content

Revamp arr_slice pass by using latest replacer APIs in asdl_cpp.py #1530

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion src/libasr/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,14 @@ def visitModule(self, mod):
self.emit(" Struct& self() { return static_cast<Struct&>(*this); }")
self.emit("public:")
self.emit(" ASR::expr_t** current_expr;")
self.emit(" SymbolTable* current_scope;")
self.emit("")
self.emit(" void call_replacer() {}")
self.emit(" void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) {")
self.emit(" for (size_t i = 0; i < n_body; i++) {", 1)
self.emit(" self().visit_stmt(*m_body[i]);", 1)
self.emit(" }", 1)
self.emit(" }")
super(CallReplacerOnExpressionsVisitor, self).visitModule(mod)
self.emit("};")

Expand All @@ -440,14 +446,34 @@ def visitConstructor(self, cons, _):

def make_visitor(self, name, fields):
self.emit("void visit_%s(const %s_t &x) {" % (name, name), 1)
is_symtab_present = False
is_stmt_present = False
symtab_field_name = ""
for field in fields:
if field.type == "stmt":
is_stmt_present = True
if field.type == "symbol_table":
is_symtab_present = True
symtab_field_name = field.name
if is_stmt_present and is_symtab_present:
break
if is_stmt_present and name not in ("Assignment", "ForAllSingle"):
self.emit(" %s_t& xx = const_cast<%s_t&>(x);" % (name, name), 1)
self.used = False
have_body = False

if is_symtab_present:
self.emit("SymbolTable* current_scope_copy = current_scope;", 2)
self.emit("current_scope = x.m_%s;" % symtab_field_name, 2)

for field in fields:
self.visitField(field)
if not self.used:
# Note: a better solution would be to change `&x` to `& /* x */`
# above, but we would need to change emit to return a string.
self.emit("if ((bool&)x) { } // Suppress unused warning", 2)

if is_symtab_present:
self.emit("current_scope = current_scope_copy;", 2)
self.emit("}", 1)

def insert_call_replacer_code(self, name, level, index=""):
Expand All @@ -462,6 +488,9 @@ def visitField(self, field):
field.type not in self.data.simple_types):
level = 2
if field.seq:
if field.type == "stmt":
self.emit("self().transform_stmts(xx.m_%s, xx.n_%s);" % (field.name, field.name), level)
return
self.used = True
self.emit("for (size_t i=0; i<x.n_%s; i++) {" % field.name, level)
if field.type in products:
Expand Down
Loading