Skip to content

Commit a27044c

Browse files
committed
Added code for visit_Module, visit_Function and visit_Variable
1 parent 42192ff commit a27044c

File tree

1 file changed

+145
-0
lines changed

1 file changed

+145
-0
lines changed

src/libasr/pass/replace_symbolic.cpp

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,151 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
2222
PassVisitor(al_, nullptr) {
2323
pass_result.reserve(al, 1);
2424
}
25+
26+
bool symbolic_replaces_with_CPtr_Module = false;
27+
bool symbolic_replaces_with_CPtr_Function = false;
28+
29+
void visit_Module(const ASR::Module_t &x) {
30+
SymbolTable* current_scope_copy = current_scope;
31+
current_scope = x.m_symtab;
32+
for (auto &a : x.m_symtab->get_scope()) {
33+
this->visit_symbol(*a.second);
34+
if(symbolic_replaces_with_CPtr_Module){
35+
std::string new_name = current_scope->get_unique_name("basic_new_stack");
36+
if(current_scope->get_symbol(new_name)) return;
37+
std::string header = "symengine/cwrapper.h";
38+
SymbolTable *fn_symtab = al.make_new<SymbolTable>(current_scope);
39+
40+
Vec<ASR::expr_t*> args;
41+
{
42+
args.reserve(al, 1);
43+
ASR::ttype_t *arg_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
44+
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
45+
al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
46+
nullptr, nullptr, ASR::storage_typeType::Default, arg_type, nullptr,
47+
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
48+
fn_symtab->add_symbol(s2c(al, "x"), arg);
49+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)));
50+
}
51+
Vec<ASR::stmt_t*> body;
52+
body.reserve(al, 1);
53+
54+
Vec<char *> dep;
55+
dep.reserve(al, 1);
56+
57+
ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc,
58+
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
59+
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
60+
ASR::deftypeType::Interface, nullptr, false, false, false,
61+
false, false, nullptr, 0, nullptr, 0, false, false, false, s2c(al, header));
62+
ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
63+
current_scope->add_symbol(new_name, new_symbol);
64+
symbolic_replaces_with_CPtr_Module = false;
65+
}
66+
}
67+
current_scope = current_scope_copy;
68+
}
69+
70+
void visit_Function(const ASR::Function_t& x) {
71+
ASR::Function_t &xx = const_cast<ASR::Function_t&>(x);
72+
SymbolTable* current_scope_copy = current_scope;
73+
current_scope = xx.m_symtab;
74+
for (auto &item : current_scope->get_scope()) {
75+
if (is_a<ASR::Variable_t>(*item.second)) {
76+
this->visit_symbol(*item.second);
77+
if(symbolic_replaces_with_CPtr_Function){
78+
std::string var = std::string(item.first);
79+
std::string placeholder = "_" + var;
80+
ASR::symbol_t* var_sym = current_scope->get_symbol(var);
81+
ASR::symbol_t* placeholder_sym = current_scope->get_symbol(placeholder);
82+
ASR::expr_t* target1 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, placeholder_sym));
83+
ASR::expr_t* target2 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, var_sym));
84+
85+
// statement 1
86+
int cast_kind = ASR::cast_kindType::IntegerToInteger;
87+
ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 4));
88+
ASR::ttype_t *type2 = ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 8));
89+
ASR::expr_t* cast_tar = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0, type1));
90+
ASR::expr_t* cast_val = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0, type2));
91+
ASR::expr_t* value1 = ASRUtils::EXPR(ASR::make_Cast_t(al, xx.base.base.loc, cast_tar, (ASR::cast_kindType)cast_kind, type2, cast_val));
92+
93+
// statement 2
94+
ASR::ttype_t *type3 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
95+
ASR::expr_t* value2 = ASRUtils::EXPR(ASR::make_PointerNullConstant_t(al, xx.base.base.loc, type3));
96+
97+
// statement 3
98+
ASR::ttype_t *type4 = ASRUtils::TYPE(ASR::make_Pointer_t(al, xx.base.base.loc, type2));
99+
ASR::expr_t* get_pointer_node = ASRUtils::EXPR(ASR::make_GetPointer_t(al, xx.base.base.loc, target1, type4, nullptr));
100+
ASR::expr_t* value3 = ASRUtils::EXPR(ASR::make_PointerToCPtr_t(al, xx.base.base.loc, get_pointer_node, type3, nullptr));
101+
102+
// statement 4
103+
ASR::symbol_t* basic_new_stack_sym = current_scope->parent->get_symbol("basic_new_stack");
104+
Vec<ASR::call_arg_t> call_args;
105+
call_args.reserve(al, 1);
106+
ASR::call_arg_t call_arg;
107+
call_arg.loc = xx.base.base.loc;
108+
call_arg.m_value = target2;
109+
call_args.push_back(al, call_arg);
110+
111+
// defining the assignment statement
112+
ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target1, value1, nullptr));
113+
ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value2, nullptr));
114+
ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value3, nullptr));
115+
//ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_new_stack_sym, nullptr, call_args.p, call_args.n, nullptr));
116+
117+
// push stmt1 into the updated body vector
118+
pass_result.push_back(al, stmt1);
119+
pass_result.push_back(al, stmt2);
120+
pass_result.push_back(al, stmt3);
121+
//pass_result.push_back(al, stmt4);
122+
123+
// updated x.m_body and x.n_bdoy with that of the updated vector
124+
// x.m_body = updated_body.p;
125+
// x.n_body = updated_body.size();
126+
transform_stmts(xx.m_body, xx.n_body);
127+
symbolic_replaces_with_CPtr_Function = false;
128+
}
129+
}
130+
}
131+
current_scope = current_scope_copy;
132+
}
133+
134+
void visit_Variable(const ASR::Variable_t& x) {
135+
SymbolTable* current_scope_copy = current_scope;
136+
current_scope = x.m_parent_symtab;
137+
if (x.m_type->type == ASR::ttypeType::SymbolicExpression) {
138+
symbolic_replaces_with_CPtr_Module = true;
139+
symbolic_replaces_with_CPtr_Function = true;
140+
std::string var_name = x.m_name;
141+
std::string placeholder = "_" + std::string(x.m_name);
142+
143+
// defining CPtr variable
144+
ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
145+
ASR::symbol_t* sym1 = ASR::down_cast<ASR::symbol_t>(
146+
ASR::make_Variable_t(al, x.base.base.loc, current_scope,
147+
s2c(al, var_name), nullptr, 0,
148+
x.m_intent, nullptr,
149+
nullptr, x.m_storage,
150+
type1, nullptr, x.m_abi,
151+
x.m_access, x.m_presence,
152+
x.m_value_attr));
153+
154+
ASR::ttype_t *type2 = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8));
155+
ASR::symbol_t* sym2 = ASR::down_cast<ASR::symbol_t>(
156+
ASR::make_Variable_t(al, x.base.base.loc, current_scope,
157+
s2c(al, placeholder), nullptr, 0,
158+
x.m_intent, nullptr,
159+
nullptr, x.m_storage,
160+
type2, nullptr, x.m_abi,
161+
x.m_access, x.m_presence,
162+
x.m_value_attr));
163+
164+
current_scope->erase_symbol(s2c(al, var_name));
165+
current_scope->add_symbol(s2c(al, var_name), sym1);
166+
current_scope->add_symbol(s2c(al, placeholder), sym2);
167+
}
168+
current_scope = current_scope_copy;
169+
}
25170
};
26171

27172
void pass_replace_symbolic(Allocator &al, ASR::TranslationUnit_t &unit,

0 commit comments

Comments
 (0)