Skip to content

Commit 0f6ccd0

Browse files
authored
Merge branch 'main' into dict-funcs
2 parents a8c32a7 + 3e8af73 commit 0f6ccd0

File tree

72 files changed

+2173
-471
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+2173
-471
lines changed

src/libasr/ASR.asdl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ symbol
9595
symbol external, identifier module_name, identifier* scope_names,
9696
identifier original_name, access access)
9797
| StructType(symbol_table symtab, identifier name, identifier* dependencies,
98-
identifier* members, abi abi, access access, bool is_packed,
98+
identifier* members, abi abi, access access, bool is_packed, bool is_abstract,
9999
expr? alignment, symbol? parent)
100100
| EnumType(symbol_table symtab, identifier name, identifier* dependencies,
101101
identifier* members, abi abi, access access, enumtype enum_value_type,
@@ -107,7 +107,7 @@ symbol
107107
abi abi, access access, presence presence, bool value_attr)
108108
| ClassType(symbol_table symtab, identifier name, abi abi, access access)
109109
| ClassProcedure(symbol_table parent_symtab, identifier name, identifier? self_argument,
110-
identifier proc_name, symbol proc, abi abi)
110+
identifier proc_name, symbol proc, abi abi, bool is_deferred)
111111
| AssociateBlock(symbol_table symtab, identifier name, stmt* body)
112112
| Block(symbol_table symtab, identifier name, stmt* body)
113113

@@ -161,15 +161,15 @@ stmt
161161
| Assign(int label, identifier variable)
162162
| Assignment(expr target, expr value, stmt? overloaded)
163163
| Associate(expr target, expr value)
164-
| Cycle()
164+
| Cycle(identifier? stmt_name)
165165
-- deallocates if allocated otherwise throws a runtime error
166166
| ExplicitDeallocate(expr* vars)
167167
-- deallocates if allocated otherwise does nothing
168-
| ImplicitDeallocate(symbol* vars)
168+
| ImplicitDeallocate(expr* vars)
169169
| DoConcurrentLoop(do_loop_head head, stmt* body)
170-
| DoLoop(do_loop_head head, stmt* body)
170+
| DoLoop(identifier? name, do_loop_head head, stmt* body)
171171
| ErrorStop(expr? code)
172-
| Exit()
172+
| Exit(identifier? stmt_name)
173173
| ForAllSingle(do_loop_head head, stmt assign_stmt)
174174
-- GoTo points to a GoToTarget with the corresponding target_id within
175175
-- the same procedure. We currently use `int` IDs to link GoTo with
@@ -201,12 +201,12 @@ stmt
201201
| Assert(expr test, expr? msg)
202202
| SubroutineCall(symbol name, symbol? original_name, call_arg* args, expr? dt)
203203
| Where(expr test, stmt* body, stmt* orelse)
204-
| WhileLoop(expr test, stmt* body)
204+
| WhileLoop(identifier? name, expr test, stmt* body)
205205
| Nullify(symbol* vars)
206206
| Flush(int label, expr unit, expr? err, expr? iomsg, expr? iostat)
207207
| ListAppend(expr a, expr ele)
208208
| AssociateBlockCall(symbol m)
209-
| SelectType(type_stmt* body, stmt* default)
209+
| SelectType(expr selector, type_stmt* body, stmt* default)
210210
| CPtrToPointer(expr cptr, expr ptr, expr? shape)
211211
| BlockCall(int label, symbol m)
212212
| SetInsert(expr a, expr ele)
@@ -420,7 +420,10 @@ do_loop_head = (expr? v, expr? start, expr? end, expr? increment)
420420

421421
case_stmt = CaseStmt(expr* test, stmt* body) | CaseStmt_Range(expr? start, expr? end, stmt* body)
422422

423-
type_stmt = TypeStmtName(symbol sym, stmt* body) | TypeStmtType(ttype type, stmt* body)
423+
type_stmt
424+
= TypeStmtName(symbol sym, stmt* body)
425+
| ClassStmt(symbol sym, stmt* body)
426+
| TypeStmtType(ttype type, stmt* body)
424427

425428
enumtype = IntegerConsecutiveFromZero | IntegerUnique | IntegerNotUnique | NonInteger
426429

src/libasr/asdl_cpp.py

Lines changed: 117 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,111 @@ def visitField(self, field):
405405
self.emit( "this->visit_symbol(*a.second);", 3)
406406
self.emit("}", 2)
407407

408+
class ASRPassWalkVisitorVisitor(ASDLVisitor):
409+
410+
def visitModule(self, mod):
411+
self.emit("/" + "*"*78 + "/")
412+
self.emit("// Walk Visitor base class")
413+
self.emit("")
414+
self.emit("template <class Struct>")
415+
self.emit("class ASRPassBaseWalkVisitor : public BaseVisitor<Struct>")
416+
self.emit("{")
417+
self.emit("private:")
418+
self.emit(" Struct& self() { return static_cast<Struct&>(*this); }")
419+
self.emit("public:")
420+
self.emit(" SymbolTable* current_scope;")
421+
self.emit(" void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) {")
422+
self.emit(" for (size_t i = 0; i < n_body; i++) {", 1)
423+
self.emit(" self().visit_stmt(*m_body[i]);", 1)
424+
self.emit(" }", 1)
425+
self.emit("}", 1)
426+
super(ASRPassWalkVisitorVisitor, self).visitModule(mod)
427+
self.emit("};")
428+
429+
def visitType(self, tp):
430+
if not (isinstance(tp.value, asdl.Sum) and
431+
is_simple_sum(tp.value)):
432+
super(ASRPassWalkVisitorVisitor, self).visitType(tp, tp.name)
433+
434+
def visitProduct(self, prod, name):
435+
self.make_visitor(name, prod.fields)
436+
437+
def visitConstructor(self, cons, _):
438+
self.make_visitor(cons.name, cons.fields)
439+
440+
def make_visitor(self, name, fields):
441+
self.emit("void visit_%s(const %s_t &x) {" % (name, name), 1)
442+
is_symtab_present = False
443+
is_stmt_present = False
444+
symtab_field_name = ""
445+
for field in fields:
446+
if field.type == "stmt":
447+
is_stmt_present = True
448+
if field.type == "symbol_table":
449+
is_symtab_present = True
450+
symtab_field_name = field.name
451+
if is_stmt_present and is_symtab_present:
452+
break
453+
if is_stmt_present and name not in ("Assignment", "ForAllSingle"):
454+
self.emit(" %s_t& xx = const_cast<%s_t&>(x);" % (name, name), 1)
455+
self.used = False
456+
457+
if is_symtab_present:
458+
self.emit("SymbolTable* current_scope_copy = current_scope;", 2)
459+
self.emit("current_scope = x.m_%s;" % symtab_field_name, 2)
460+
461+
for field in fields:
462+
self.visitField(field)
463+
if not self.used:
464+
# Note: a better solution would be to change `&x` to `& /* x */`
465+
# above, but we would need to change emit to return a string.
466+
self.emit("if ((bool&)x) { } // Suppress unused warning", 2)
467+
468+
if is_symtab_present:
469+
self.emit("current_scope = current_scope_copy;", 2)
470+
471+
self.emit("}", 1)
472+
473+
def visitField(self, field):
474+
if (field.type not in asdl.builtin_types and
475+
field.type not in self.data.simple_types):
476+
level = 2
477+
if field.seq:
478+
if field.type == "stmt":
479+
self.emit("self().transform_stmts(xx.m_%s, xx.n_%s);" % (field.name, field.name), level)
480+
return
481+
self.used = True
482+
self.emit("for (size_t i=0; i<x.n_%s; i++) {" % field.name, level)
483+
if field.type in products:
484+
self.emit(" self().visit_%s(x.m_%s[i]);" % (field.type, field.name), level)
485+
else:
486+
if field.type != "symbol":
487+
self.emit(" self().visit_%s(*x.m_%s[i]);" % (field.type, field.name), level)
488+
self.emit("}", level)
489+
else:
490+
if field.type in products:
491+
self.used = True
492+
if field.opt:
493+
self.emit("if (x.m_%s)" % field.name, 2)
494+
level = 3
495+
if field.opt:
496+
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
497+
else:
498+
self.emit("self().visit_%s(x.m_%s);" % (field.type, field.name), level)
499+
else:
500+
if field.type != "symbol":
501+
self.used = True
502+
if field.opt:
503+
self.emit("if (x.m_%s)" % field.name, 2)
504+
level = 3
505+
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
506+
elif field.type == "symbol_table" and field.name in["symtab",
507+
"global_scope"]:
508+
self.used = True
509+
self.emit("for (auto &a : x.m_%s->get_scope()) {" % field.name, 2)
510+
self.emit( "this->visit_symbol(*a.second);", 3)
511+
self.emit("}", 2)
512+
408513
class CallReplacerOnExpressionsVisitor(ASDLVisitor):
409514

410515
def __init__(self, stream, data):
@@ -477,10 +582,10 @@ def make_visitor(self, name, fields):
477582
self.emit("}", 1)
478583

479584
def insert_call_replacer_code(self, name, level, index=""):
480-
self.emit(" ASR::expr_t** current_expr_copy_%d = current_expr;" % (self.current_expr_copy_variable_count), level)
481-
self.emit(" current_expr = const_cast<ASR::expr_t**>(&(x.m_%s%s));" % (name, index), level)
482-
self.emit(" self().call_replacer();", level)
483-
self.emit(" current_expr = current_expr_copy_%d;" % (self.current_expr_copy_variable_count), level)
585+
self.emit("ASR::expr_t** current_expr_copy_%d = current_expr;" % (self.current_expr_copy_variable_count), level)
586+
self.emit("current_expr = const_cast<ASR::expr_t**>(&(x.m_%s%s));" % (name, index), level)
587+
self.emit("self().call_replacer();", level)
588+
self.emit("current_expr = current_expr_copy_%d;" % (self.current_expr_copy_variable_count), level)
484589
self.current_expr_copy_variable_count += 1
485590

486591
def visitField(self, field):
@@ -495,12 +600,14 @@ def visitField(self, field):
495600
self.emit("for (size_t i=0; i<x.n_%s; i++) {" % field.name, level)
496601
if field.type in products:
497602
if field.type == "expr":
498-
self.insert_call_replacer_code(field.name, level, "[i]")
603+
self.insert_call_replacer_code(field.name, level + 1, "[i]")
604+
self.emit("if( x.m_%s[i] )" % (field.name), level)
499605
self.emit(" self().visit_%s(x.m_%s[i]);" % (field.type, field.name), level)
500606
else:
501607
if field.type != "symbol":
502608
if field.type == "expr":
503-
self.insert_call_replacer_code(field.name, level, "[i]")
609+
self.insert_call_replacer_code(field.name, level + 1, "[i]")
610+
self.emit("if( x.m_%s[i] )" % (field.name), level + 1)
504611
self.emit(" self().visit_%s(*x.m_%s[i]);" % (field.type, field.name), level)
505612
self.emit("}", level)
506613
else:
@@ -511,6 +618,7 @@ def visitField(self, field):
511618
level = 3
512619
if field.type == "expr":
513620
self.insert_call_replacer_code(field.name, level)
621+
self.emit("if( x.m_%s )" % (field.name), level)
514622
if field.opt:
515623
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
516624
self.emit("}", 2)
@@ -524,6 +632,7 @@ def visitField(self, field):
524632
level = 3
525633
if field.type == "expr":
526634
self.insert_call_replacer_code(field.name, level)
635+
self.emit("if( x.m_%s )" % (field.name), level)
527636
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
528637
if field.opt:
529638
self.emit("}", 2)
@@ -2595,6 +2704,8 @@ def main(argv):
25952704

25962705
try:
25972706
if is_asr:
2707+
ASRPassWalkVisitorVisitor(fp, data).visit(mod)
2708+
fp.write("\n\n")
25982709
ExprStmtDuplicatorVisitor(fp, data).visit(mod)
25992710
fp.write("\n\n")
26002711
ExprBaseReplacerVisitor(fp, data).visit(mod)

src/libasr/asr_scopes.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,21 @@ struct SymbolTable {
7171
scope.erase(name);
7272
}
7373

74+
// Add a new symbol that did not exist before
7475
void add_symbol(const std::string &name, ASR::symbol_t* symbol) {
76+
LCOMPILERS_ASSERT(scope.find(name) == scope.end())
77+
scope[name] = symbol;
78+
}
79+
80+
// Overwrite an existing symbol
81+
void overwrite_symbol(const std::string &name, ASR::symbol_t* symbol) {
82+
LCOMPILERS_ASSERT(scope.find(name) != scope.end())
83+
scope[name] = symbol;
84+
}
85+
86+
// Use as the last resort, prefer to always either add a new symbol
87+
// or overwrite an existing one, not both
88+
void add_or_overwrite_symbol(const std::string &name, ASR::symbol_t* symbol) {
7589
scope[name] = symbol;
7690
}
7791

0 commit comments

Comments
 (0)