Skip to content

Sync libasr from LFortran #1670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ symbol
symbol external, identifier module_name, identifier* scope_names,
identifier original_name, access access)
| StructType(symbol_table symtab, identifier name, identifier* dependencies,
identifier* members, abi abi, access access, bool is_packed,
identifier* members, abi abi, access access, bool is_packed, bool is_abstract,
expr? alignment, symbol? parent)
| EnumType(symbol_table symtab, identifier name, identifier* dependencies,
identifier* members, abi abi, access access, enumtype enum_value_type,
Expand All @@ -107,7 +107,7 @@ symbol
abi abi, access access, presence presence, bool value_attr)
| ClassType(symbol_table symtab, identifier name, abi abi, access access)
| ClassProcedure(symbol_table parent_symtab, identifier name, identifier? self_argument,
identifier proc_name, symbol proc, abi abi)
identifier proc_name, symbol proc, abi abi, bool is_deferred)
| AssociateBlock(symbol_table symtab, identifier name, stmt* body)
| Block(symbol_table symtab, identifier name, stmt* body)

Expand Down Expand Up @@ -161,15 +161,15 @@ stmt
| Assign(int label, identifier variable)
| Assignment(expr target, expr value, stmt? overloaded)
| Associate(expr target, expr value)
| Cycle()
| Cycle(identifier? stmt_name)
-- deallocates if allocated otherwise throws a runtime error
| ExplicitDeallocate(expr* vars)
-- deallocates if allocated otherwise does nothing
| ImplicitDeallocate(symbol* vars)
| ImplicitDeallocate(expr* vars)
| DoConcurrentLoop(do_loop_head head, stmt* body)
| DoLoop(do_loop_head head, stmt* body)
| DoLoop(identifier? name, do_loop_head head, stmt* body)
| ErrorStop(expr? code)
| Exit()
| Exit(identifier? stmt_name)
| ForAllSingle(do_loop_head head, stmt assign_stmt)
-- GoTo points to a GoToTarget with the corresponding target_id within
-- the same procedure. We currently use `int` IDs to link GoTo with
Expand Down Expand Up @@ -201,12 +201,12 @@ stmt
| Assert(expr test, expr? msg)
| SubroutineCall(symbol name, symbol? original_name, call_arg* args, expr? dt)
| Where(expr test, stmt* body, stmt* orelse)
| WhileLoop(expr test, stmt* body)
| WhileLoop(identifier? name, expr test, stmt* body)
| Nullify(symbol* vars)
| Flush(int label, expr unit, expr? err, expr? iomsg, expr? iostat)
| ListAppend(expr a, expr ele)
| AssociateBlockCall(symbol m)
| SelectType(type_stmt* body, stmt* default)
| SelectType(expr selector, type_stmt* body, stmt* default)
| CPtrToPointer(expr cptr, expr ptr, expr? shape)
| BlockCall(int label, symbol m)
| SetInsert(expr a, expr ele)
Expand Down Expand Up @@ -420,7 +420,10 @@ do_loop_head = (expr? v, expr? start, expr? end, expr? increment)

case_stmt = CaseStmt(expr* test, stmt* body) | CaseStmt_Range(expr? start, expr? end, stmt* body)

type_stmt = TypeStmtName(symbol sym, stmt* body) | TypeStmtType(ttype type, stmt* body)
type_stmt
= TypeStmtName(symbol sym, stmt* body)
| ClassStmt(symbol sym, stmt* body)
| TypeStmtType(ttype type, stmt* body)

enumtype = IntegerConsecutiveFromZero | IntegerUnique | IntegerNotUnique | NonInteger

Expand Down
123 changes: 117 additions & 6 deletions src/libasr/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,111 @@ def visitField(self, field):
self.emit( "this->visit_symbol(*a.second);", 3)
self.emit("}", 2)

class ASRPassWalkVisitorVisitor(ASDLVisitor):

def visitModule(self, mod):
self.emit("/" + "*"*78 + "/")
self.emit("// Walk Visitor base class")
self.emit("")
self.emit("template <class Struct>")
self.emit("class ASRPassBaseWalkVisitor : public BaseVisitor<Struct>")
self.emit("{")
self.emit("private:")
self.emit(" Struct& self() { return static_cast<Struct&>(*this); }")
self.emit("public:")
self.emit(" SymbolTable* current_scope;")
self.emit(" void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) {")
self.emit(" for (size_t i = 0; i < n_body; i++) {", 1)
self.emit(" self().visit_stmt(*m_body[i]);", 1)
self.emit(" }", 1)
self.emit("}", 1)
super(ASRPassWalkVisitorVisitor, self).visitModule(mod)
self.emit("};")

def visitType(self, tp):
if not (isinstance(tp.value, asdl.Sum) and
is_simple_sum(tp.value)):
super(ASRPassWalkVisitorVisitor, self).visitType(tp, tp.name)

def visitProduct(self, prod, name):
self.make_visitor(name, prod.fields)

def visitConstructor(self, cons, _):
self.make_visitor(cons.name, cons.fields)

def make_visitor(self, name, fields):
self.emit("void visit_%s(const %s_t &x) {" % (name, name), 1)
is_symtab_present = False
is_stmt_present = False
symtab_field_name = ""
for field in fields:
if field.type == "stmt":
is_stmt_present = True
if field.type == "symbol_table":
is_symtab_present = True
symtab_field_name = field.name
if is_stmt_present and is_symtab_present:
break
if is_stmt_present and name not in ("Assignment", "ForAllSingle"):
self.emit(" %s_t& xx = const_cast<%s_t&>(x);" % (name, name), 1)
self.used = False

if is_symtab_present:
self.emit("SymbolTable* current_scope_copy = current_scope;", 2)
self.emit("current_scope = x.m_%s;" % symtab_field_name, 2)

for field in fields:
self.visitField(field)
if not self.used:
# Note: a better solution would be to change `&x` to `& /* x */`
# above, but we would need to change emit to return a string.
self.emit("if ((bool&)x) { } // Suppress unused warning", 2)

if is_symtab_present:
self.emit("current_scope = current_scope_copy;", 2)

self.emit("}", 1)

def visitField(self, field):
if (field.type not in asdl.builtin_types and
field.type not in self.data.simple_types):
level = 2
if field.seq:
if field.type == "stmt":
self.emit("self().transform_stmts(xx.m_%s, xx.n_%s);" % (field.name, field.name), level)
return
self.used = True
self.emit("for (size_t i=0; i<x.n_%s; i++) {" % field.name, level)
if field.type in products:
self.emit(" self().visit_%s(x.m_%s[i]);" % (field.type, field.name), level)
else:
if field.type != "symbol":
self.emit(" self().visit_%s(*x.m_%s[i]);" % (field.type, field.name), level)
self.emit("}", level)
else:
if field.type in products:
self.used = True
if field.opt:
self.emit("if (x.m_%s)" % field.name, 2)
level = 3
if field.opt:
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
else:
self.emit("self().visit_%s(x.m_%s);" % (field.type, field.name), level)
else:
if field.type != "symbol":
self.used = True
if field.opt:
self.emit("if (x.m_%s)" % field.name, 2)
level = 3
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
elif field.type == "symbol_table" and field.name in["symtab",
"global_scope"]:
self.used = True
self.emit("for (auto &a : x.m_%s->get_scope()) {" % field.name, 2)
self.emit( "this->visit_symbol(*a.second);", 3)
self.emit("}", 2)

class CallReplacerOnExpressionsVisitor(ASDLVisitor):

def __init__(self, stream, data):
Expand Down Expand Up @@ -477,10 +582,10 @@ def make_visitor(self, name, fields):
self.emit("}", 1)

def insert_call_replacer_code(self, name, level, index=""):
self.emit(" ASR::expr_t** current_expr_copy_%d = current_expr;" % (self.current_expr_copy_variable_count), level)
self.emit(" current_expr = const_cast<ASR::expr_t**>(&(x.m_%s%s));" % (name, index), level)
self.emit(" self().call_replacer();", level)
self.emit(" current_expr = current_expr_copy_%d;" % (self.current_expr_copy_variable_count), level)
self.emit("ASR::expr_t** current_expr_copy_%d = current_expr;" % (self.current_expr_copy_variable_count), level)
self.emit("current_expr = const_cast<ASR::expr_t**>(&(x.m_%s%s));" % (name, index), level)
self.emit("self().call_replacer();", level)
self.emit("current_expr = current_expr_copy_%d;" % (self.current_expr_copy_variable_count), level)
self.current_expr_copy_variable_count += 1

def visitField(self, field):
Expand All @@ -495,12 +600,14 @@ def visitField(self, field):
self.emit("for (size_t i=0; i<x.n_%s; i++) {" % field.name, level)
if field.type in products:
if field.type == "expr":
self.insert_call_replacer_code(field.name, level, "[i]")
self.insert_call_replacer_code(field.name, level + 1, "[i]")
self.emit("if( x.m_%s[i] )" % (field.name), level)
self.emit(" self().visit_%s(x.m_%s[i]);" % (field.type, field.name), level)
else:
if field.type != "symbol":
if field.type == "expr":
self.insert_call_replacer_code(field.name, level, "[i]")
self.insert_call_replacer_code(field.name, level + 1, "[i]")
self.emit("if( x.m_%s[i] )" % (field.name), level + 1)
self.emit(" self().visit_%s(*x.m_%s[i]);" % (field.type, field.name), level)
self.emit("}", level)
else:
Expand All @@ -511,6 +618,7 @@ def visitField(self, field):
level = 3
if field.type == "expr":
self.insert_call_replacer_code(field.name, level)
self.emit("if( x.m_%s )" % (field.name), level)
if field.opt:
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
self.emit("}", 2)
Expand All @@ -524,6 +632,7 @@ def visitField(self, field):
level = 3
if field.type == "expr":
self.insert_call_replacer_code(field.name, level)
self.emit("if( x.m_%s )" % (field.name), level)
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
if field.opt:
self.emit("}", 2)
Expand Down Expand Up @@ -2595,6 +2704,8 @@ def main(argv):

try:
if is_asr:
ASRPassWalkVisitorVisitor(fp, data).visit(mod)
fp.write("\n\n")
ExprStmtDuplicatorVisitor(fp, data).visit(mod)
fp.write("\n\n")
ExprBaseReplacerVisitor(fp, data).visit(mod)
Expand Down
14 changes: 14 additions & 0 deletions src/libasr/asr_scopes.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,21 @@ struct SymbolTable {
scope.erase(name);
}

// Add a new symbol that did not exist before
void add_symbol(const std::string &name, ASR::symbol_t* symbol) {
LCOMPILERS_ASSERT(scope.find(name) == scope.end())
scope[name] = symbol;
}

// Overwrite an existing symbol
void overwrite_symbol(const std::string &name, ASR::symbol_t* symbol) {
LCOMPILERS_ASSERT(scope.find(name) != scope.end())
scope[name] = symbol;
}

// Use as the last resort, prefer to always either add a new symbol
// or overwrite an existing one, not both
void add_or_overwrite_symbol(const std::string &name, ASR::symbol_t* symbol) {
scope[name] = symbol;
}

Expand Down
Loading