@@ -422,8 +422,14 @@ def visitModule(self, mod):
422
422
self .emit (" Struct& self() { return static_cast<Struct&>(*this); }" )
423
423
self .emit ("public:" )
424
424
self .emit (" ASR::expr_t** current_expr;" )
425
+ self .emit (" SymbolTable* current_scope;" )
425
426
self .emit ("" )
426
427
self .emit (" void call_replacer() {}" )
428
+ self .emit (" void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) {" )
429
+ self .emit (" for (size_t i = 0; i < n_body; i++) {" , 1 )
430
+ self .emit (" self().visit_stmt(*m_body[i]);" , 1 )
431
+ self .emit (" }" , 1 )
432
+ self .emit (" }" )
427
433
super (CallReplacerOnExpressionsVisitor , self ).visitModule (mod )
428
434
self .emit ("};" )
429
435
@@ -440,14 +446,34 @@ def visitConstructor(self, cons, _):
440
446
441
447
def make_visitor (self , name , fields ):
442
448
self .emit ("void visit_%s(const %s_t &x) {" % (name , name ), 1 )
449
+ is_symtab_present = False
450
+ is_stmt_present = False
451
+ symtab_field_name = ""
452
+ for field in fields :
453
+ if field .type == "stmt" :
454
+ is_stmt_present = True
455
+ if field .type == "symbol_table" :
456
+ is_symtab_present = True
457
+ symtab_field_name = field .name
458
+ if is_stmt_present and is_symtab_present :
459
+ break
460
+ if is_stmt_present and name not in ("Assignment" , "ForAllSingle" ):
461
+ self .emit (" %s_t& xx = const_cast<%s_t&>(x);" % (name , name ), 1 )
443
462
self .used = False
444
- have_body = False
463
+
464
+ if is_symtab_present :
465
+ self .emit ("SymbolTable* current_scope_copy = current_scope;" , 2 )
466
+ self .emit ("current_scope = x.m_%s;" % symtab_field_name , 2 )
467
+
445
468
for field in fields :
446
469
self .visitField (field )
447
470
if not self .used :
448
471
# Note: a better solution would be to change `&x` to `& /* x */`
449
472
# above, but we would need to change emit to return a string.
450
473
self .emit ("if ((bool&)x) { } // Suppress unused warning" , 2 )
474
+
475
+ if is_symtab_present :
476
+ self .emit ("current_scope = current_scope_copy;" , 2 )
451
477
self .emit ("}" , 1 )
452
478
453
479
def insert_call_replacer_code (self , name , level , index = "" ):
@@ -462,6 +488,9 @@ def visitField(self, field):
462
488
field .type not in self .data .simple_types ):
463
489
level = 2
464
490
if field .seq :
491
+ if field .type == "stmt" :
492
+ self .emit ("self().transform_stmts(xx.m_%s, xx.n_%s);" % (field .name , field .name ), level )
493
+ return
465
494
self .used = True
466
495
self .emit ("for (size_t i=0; i<x.n_%s; i++) {" % field .name , level )
467
496
if field .type in products :
0 commit comments