@@ -405,6 +405,111 @@ def visitField(self, field):
405
405
self .emit ( "this->visit_symbol(*a.second);" , 3 )
406
406
self .emit ("}" , 2 )
407
407
408
+ class ASRPassWalkVisitorVisitor (ASDLVisitor ):
409
+
410
+ def visitModule (self , mod ):
411
+ self .emit ("/" + "*" * 78 + "/" )
412
+ self .emit ("// Walk Visitor base class" )
413
+ self .emit ("" )
414
+ self .emit ("template <class Struct>" )
415
+ self .emit ("class ASRPassBaseWalkVisitor : public BaseVisitor<Struct>" )
416
+ self .emit ("{" )
417
+ self .emit ("private:" )
418
+ self .emit (" Struct& self() { return static_cast<Struct&>(*this); }" )
419
+ self .emit ("public:" )
420
+ self .emit (" SymbolTable* current_scope;" )
421
+ self .emit (" void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) {" )
422
+ self .emit (" for (size_t i = 0; i < n_body; i++) {" , 1 )
423
+ self .emit (" self().visit_stmt(*m_body[i]);" , 1 )
424
+ self .emit (" }" , 1 )
425
+ self .emit ("}" , 1 )
426
+ super (ASRPassWalkVisitorVisitor , self ).visitModule (mod )
427
+ self .emit ("};" )
428
+
429
+ def visitType (self , tp ):
430
+ if not (isinstance (tp .value , asdl .Sum ) and
431
+ is_simple_sum (tp .value )):
432
+ super (ASRPassWalkVisitorVisitor , self ).visitType (tp , tp .name )
433
+
434
+ def visitProduct (self , prod , name ):
435
+ self .make_visitor (name , prod .fields )
436
+
437
+ def visitConstructor (self , cons , _ ):
438
+ self .make_visitor (cons .name , cons .fields )
439
+
440
+ def make_visitor (self , name , fields ):
441
+ self .emit ("void visit_%s(const %s_t &x) {" % (name , name ), 1 )
442
+ is_symtab_present = False
443
+ is_stmt_present = False
444
+ symtab_field_name = ""
445
+ for field in fields :
446
+ if field .type == "stmt" :
447
+ is_stmt_present = True
448
+ if field .type == "symbol_table" :
449
+ is_symtab_present = True
450
+ symtab_field_name = field .name
451
+ if is_stmt_present and is_symtab_present :
452
+ break
453
+ if is_stmt_present and name not in ("Assignment" , "ForAllSingle" ):
454
+ self .emit (" %s_t& xx = const_cast<%s_t&>(x);" % (name , name ), 1 )
455
+ self .used = False
456
+
457
+ if is_symtab_present :
458
+ self .emit ("SymbolTable* current_scope_copy = current_scope;" , 2 )
459
+ self .emit ("current_scope = x.m_%s;" % symtab_field_name , 2 )
460
+
461
+ for field in fields :
462
+ self .visitField (field )
463
+ if not self .used :
464
+ # Note: a better solution would be to change `&x` to `& /* x */`
465
+ # above, but we would need to change emit to return a string.
466
+ self .emit ("if ((bool&)x) { } // Suppress unused warning" , 2 )
467
+
468
+ if is_symtab_present :
469
+ self .emit ("current_scope = current_scope_copy;" , 2 )
470
+
471
+ self .emit ("}" , 1 )
472
+
473
+ def visitField (self , field ):
474
+ if (field .type not in asdl .builtin_types and
475
+ field .type not in self .data .simple_types ):
476
+ level = 2
477
+ if field .seq :
478
+ if field .type == "stmt" :
479
+ self .emit ("self().transform_stmts(xx.m_%s, xx.n_%s);" % (field .name , field .name ), level )
480
+ return
481
+ self .used = True
482
+ self .emit ("for (size_t i=0; i<x.n_%s; i++) {" % field .name , level )
483
+ if field .type in products :
484
+ self .emit (" self().visit_%s(x.m_%s[i]);" % (field .type , field .name ), level )
485
+ else :
486
+ if field .type != "symbol" :
487
+ self .emit (" self().visit_%s(*x.m_%s[i]);" % (field .type , field .name ), level )
488
+ self .emit ("}" , level )
489
+ else :
490
+ if field .type in products :
491
+ self .used = True
492
+ if field .opt :
493
+ self .emit ("if (x.m_%s)" % field .name , 2 )
494
+ level = 3
495
+ if field .opt :
496
+ self .emit ("self().visit_%s(*x.m_%s);" % (field .type , field .name ), level )
497
+ else :
498
+ self .emit ("self().visit_%s(x.m_%s);" % (field .type , field .name ), level )
499
+ else :
500
+ if field .type != "symbol" :
501
+ self .used = True
502
+ if field .opt :
503
+ self .emit ("if (x.m_%s)" % field .name , 2 )
504
+ level = 3
505
+ self .emit ("self().visit_%s(*x.m_%s);" % (field .type , field .name ), level )
506
+ elif field .type == "symbol_table" and field .name in ["symtab" ,
507
+ "global_scope" ]:
508
+ self .used = True
509
+ self .emit ("for (auto &a : x.m_%s->get_scope()) {" % field .name , 2 )
510
+ self .emit ( "this->visit_symbol(*a.second);" , 3 )
511
+ self .emit ("}" , 2 )
512
+
408
513
class CallReplacerOnExpressionsVisitor (ASDLVisitor ):
409
514
410
515
def __init__ (self , stream , data ):
@@ -477,10 +582,10 @@ def make_visitor(self, name, fields):
477
582
self .emit ("}" , 1 )
478
583
479
584
def insert_call_replacer_code (self , name , level , index = "" ):
480
- self .emit (" ASR::expr_t** current_expr_copy_%d = current_expr;" % (self .current_expr_copy_variable_count ), level )
481
- self .emit (" current_expr = const_cast<ASR::expr_t**>(&(x.m_%s%s));" % (name , index ), level )
482
- self .emit (" self().call_replacer();" , level )
483
- self .emit (" current_expr = current_expr_copy_%d;" % (self .current_expr_copy_variable_count ), level )
585
+ self .emit ("ASR::expr_t** current_expr_copy_%d = current_expr;" % (self .current_expr_copy_variable_count ), level )
586
+ self .emit ("current_expr = const_cast<ASR::expr_t**>(&(x.m_%s%s));" % (name , index ), level )
587
+ self .emit ("self().call_replacer();" , level )
588
+ self .emit ("current_expr = current_expr_copy_%d;" % (self .current_expr_copy_variable_count ), level )
484
589
self .current_expr_copy_variable_count += 1
485
590
486
591
def visitField (self , field ):
@@ -495,12 +600,14 @@ def visitField(self, field):
495
600
self .emit ("for (size_t i=0; i<x.n_%s; i++) {" % field .name , level )
496
601
if field .type in products :
497
602
if field .type == "expr" :
498
- self .insert_call_replacer_code (field .name , level , "[i]" )
603
+ self .insert_call_replacer_code (field .name , level + 1 , "[i]" )
604
+ self .emit ("if( x.m_%s[i] )" % (field .name ), level )
499
605
self .emit (" self().visit_%s(x.m_%s[i]);" % (field .type , field .name ), level )
500
606
else :
501
607
if field .type != "symbol" :
502
608
if field .type == "expr" :
503
- self .insert_call_replacer_code (field .name , level , "[i]" )
609
+ self .insert_call_replacer_code (field .name , level + 1 , "[i]" )
610
+ self .emit ("if( x.m_%s[i] )" % (field .name ), level + 1 )
504
611
self .emit (" self().visit_%s(*x.m_%s[i]);" % (field .type , field .name ), level )
505
612
self .emit ("}" , level )
506
613
else :
@@ -511,6 +618,7 @@ def visitField(self, field):
511
618
level = 3
512
619
if field .type == "expr" :
513
620
self .insert_call_replacer_code (field .name , level )
621
+ self .emit ("if( x.m_%s )" % (field .name ), level )
514
622
if field .opt :
515
623
self .emit ("self().visit_%s(*x.m_%s);" % (field .type , field .name ), level )
516
624
self .emit ("}" , 2 )
@@ -524,6 +632,7 @@ def visitField(self, field):
524
632
level = 3
525
633
if field .type == "expr" :
526
634
self .insert_call_replacer_code (field .name , level )
635
+ self .emit ("if( x.m_%s )" % (field .name ), level )
527
636
self .emit ("self().visit_%s(*x.m_%s);" % (field .type , field .name ), level )
528
637
if field .opt :
529
638
self .emit ("}" , 2 )
@@ -2595,6 +2704,8 @@ def main(argv):
2595
2704
2596
2705
try :
2597
2706
if is_asr :
2707
+ ASRPassWalkVisitorVisitor (fp , data ).visit (mod )
2708
+ fp .write ("\n \n " )
2598
2709
ExprStmtDuplicatorVisitor (fp , data ).visit (mod )
2599
2710
fp .write ("\n \n " )
2600
2711
ExprBaseReplacerVisitor (fp , data ).visit (mod )
0 commit comments