Skip to content

Update ASR from LFortran #427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,11 @@ symbol
| ClassType(symbol_table symtab, identifier name, abi abi, access access)
| ClassProcedure(symbol_table parent_symtab, identifier name, identifier
proc_name, symbol proc, abi abi)
| AssociateBlock(symbol_table symtab, identifier name, stmt* body)

storage_type = Default | Save | Parameter | Allocatable
access = Public | Private
intent = Local | In | Out | InOut | ReturnVar | Unspecified | AssociateBlock
intent = Local | In | Out | InOut | ReturnVar | Unspecified
deftype = Implementation | Interface
presence = Required | Optional

Expand Down Expand Up @@ -199,6 +200,7 @@ stmt
| Nullify(symbol* vars)
| Flush(int label, expr unit, expr? err, expr? iomsg, expr? iostat)
| ListAppend(symbol a, expr ele)
| AssociateBlockCall(symbol m)
| SetInsert(symbol a, expr ele)
| SetRemove(symbol a, expr ele)
| ListInsert(symbol a, expr pos, expr ele)
Expand Down Expand Up @@ -245,6 +247,7 @@ expr
| DictLen(expr arg, ttype type, expr? value)
| Var(symbol v)
| ArrayRef(symbol v, array_index* args, ttype type, expr? value)
| ArraySize(expr v, expr? dim, ttype type, expr? value)
| DerivedRef(expr v, symbol m, ttype type, expr? value)
| Cast(expr arg, cast_kind kind, ttype type, expr? value)
| ComplexRe(expr arg, ttype type, expr? value)
Expand Down
11 changes: 11 additions & 0 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ static inline ASR::ttype_t* expr_type(const ASR::expr_t *f)
case ASR::exprType::IntegerBOZ: { return ((ASR::IntegerBOZ_t*)f)->m_type; }
case ASR::exprType::Var: { return EXPR2VAR(f)->m_type; }
case ASR::exprType::ArrayRef: { return ((ASR::ArrayRef_t*)f)->m_type; }
case ASR::exprType::ArraySize: { return ((ASR::ArraySize_t*)f)->m_type; }
case ASR::exprType::DerivedRef: { return ((ASR::DerivedRef_t*)f)->m_type; }
case ASR::exprType::Cast: { return ((ASR::Cast_t*)f)->m_type; }
case ASR::exprType::ComplexRe: { return ((ASR::ComplexRe_t*)f)->m_type; }
Expand Down Expand Up @@ -270,6 +271,7 @@ static inline ASR::expr_t* expr_value(ASR::expr_t *f)
case ASR::exprType::Compare: { return ASR::down_cast<ASR::Compare_t>(f)->m_value; }
case ASR::exprType::FunctionCall: { return ASR::down_cast<ASR::FunctionCall_t>(f)->m_value; }
case ASR::exprType::ArrayRef: { return ASR::down_cast<ASR::ArrayRef_t>(f)->m_value; }
case ASR::exprType::ArraySize: { return ASR::down_cast<ASR::ArraySize_t>(f)->m_value; }
case ASR::exprType::DerivedRef: { return ASR::down_cast<ASR::DerivedRef_t>(f)->m_value; }
case ASR::exprType::Cast: { return ASR::down_cast<ASR::Cast_t>(f)->m_value; }
case ASR::exprType::Var: { return EXPR2VAR(f)->m_value; }
Expand Down Expand Up @@ -331,6 +333,9 @@ static inline char *symbol_name(const ASR::symbol_t *f)
case ASR::symbolType::CustomOperator: {
return ASR::down_cast<ASR::CustomOperator_t>(f)->m_name;
}
case ASR::symbolType::AssociateBlock: {
return ASR::down_cast<ASR::AssociateBlock_t>(f)->m_name;
}
default : throw LFortranException("Not implemented");
}
}
Expand Down Expand Up @@ -368,6 +373,9 @@ static inline SymbolTable *symbol_parent_symtab(const ASR::symbol_t *f)
case ASR::symbolType::CustomOperator: {
return ASR::down_cast<ASR::CustomOperator_t>(f)->m_parent_symtab;
}
case ASR::symbolType::AssociateBlock: {
return ASR::down_cast<ASR::AssociateBlock_t>(f)->m_symtab->parent;
}
default : throw LFortranException("Not implemented");
}
}
Expand Down Expand Up @@ -407,6 +415,9 @@ static inline SymbolTable *symbol_symtab(const ASR::symbol_t *f)
return nullptr;
//throw LFortranException("ClassProcedure does not have a symtab");
}
case ASR::symbolType::AssociateBlock: {
return ASR::down_cast<ASR::AssociateBlock_t>(f)->m_symtab;
}
default : throw LFortranException("Not implemented");
}
}
Expand Down
23 changes: 23 additions & 0 deletions src/libasr/asr_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,29 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
current_symtab = parent_symtab;
}

void visit_AssociateBlock(const AssociateBlock_t& x) {
SymbolTable *parent_symtab = current_symtab;
current_symtab = x.m_symtab;
require(x.m_symtab != nullptr,
"The AssociateBlock::m_symtab cannot be nullptr");
require(x.m_symtab->parent == parent_symtab,
"The AssociateBlock::m_symtab->parent is not the right parent");
require(id_symtab_map.find(x.m_symtab->counter) == id_symtab_map.end(),
"AssociateBlock::m_symtab->counter must be unique");
require(x.m_symtab->asr_owner == (ASR::asr_t*)&x,
"The X::m_symtab::asr_owner must point to X");
require(ASRUtils::symbol_symtab(down_cast<symbol_t>(current_symtab->asr_owner)) == current_symtab,
"The asr_owner invariant failed");
id_symtab_map[x.m_symtab->counter] = x.m_symtab;
for (auto &a : x.m_symtab->scope) {
this->visit_symbol(*a.second);
}
for (size_t i=0; i<x.n_body; i++) {
visit_stmt(*x.m_body[i]);
}
current_symtab = parent_symtab;
}

void visit_Module(const Module_t &x) {
SymbolTable *parent_symtab = current_symtab;
current_symtab = x.m_symtab;
Expand Down
18 changes: 18 additions & 0 deletions src/libasr/codegen/asr_to_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,24 @@ Kokkos::View<T*> from_std_vector(const std::vector<T> &v)
last_expr_precedence = 2;
}

void visit_ArraySize(const ASR::ArraySize_t& x) {
visit_expr(*x.m_v);
std::string var_name = src;
std::string args = "";
if (x.m_dim == nullptr) {
// TODO: return the product of all dimensions:
args = "0";
} else {
if( x.m_dim ) {
visit_expr(*x.m_dim);
args += src + "-1";
args += ", ";
}
args += std::to_string(ASRUtils::extract_kind_from_ttype_t(x.m_type)) + "-1";
}
src = var_name + ".extent(" + args + ")";
}

void visit_Assignment(const ASR::Assignment_t &x) {
std::string target;
if (ASR::is_a<ASR::Var_t>(*x.m_target)) {
Expand Down
133 changes: 60 additions & 73 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -880,10 +880,6 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
complex_type_8_ptr = llvm::StructType::create(context, els_8_ptr, "complex_8_ptr");
character_type = llvm::Type::getInt8PtrTy(context);

llvm::Type* size_arg = (llvm::Type*)llvm::StructType::create(context, std::vector<llvm::Type*>({
arr_descr->get_dimension_descriptor_type(true),
getIntType(4)}), "size_arg");
fname2arg_type["size"] = std::make_pair(size_arg, size_arg->getPointerTo());
llvm::Type* bound_arg = static_cast<llvm::Type*>(arr_descr->get_dimension_descriptor_type(true));
fname2arg_type["lbound"] = std::make_pair(bound_arg, bound_arg->getPointerTo());
fname2arg_type["ubound"] = std::make_pair(bound_arg, bound_arg->getPointerTo());
Expand Down Expand Up @@ -2356,50 +2352,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
} else if( x.m_abi == ASR::abiType::Intrinsic &&
x.m_deftype == ASR::deftypeType::Interface ) {
std::string m_name = x.m_name;
if( m_name == "size" ) {

define_function_entry(x);

// Defines the size intrinsic's body at LLVM level.
ASR::Variable_t *arg = EXPR2VAR(x.m_args[0]);
uint32_t h = get_hash((ASR::asr_t*)arg);
llvm::Value* llvm_arg = llvm_symtab[h];
ASR::Variable_t *ret = EXPR2VAR(x.m_return_var);
h = get_hash((ASR::asr_t*)ret);
llvm::Value* llvm_ret_ptr = llvm_symtab[h];
llvm::Value* dim_des_val = CreateLoad(llvm_utils->create_gep(llvm_arg, 0));
llvm::Value* rank = CreateLoad(llvm_utils->create_gep(llvm_arg, 1));
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 1)), llvm_ret_ptr);

llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
this->current_loophead = loophead;
this->current_loopend = loopend;

llvm::Value* r = builder->CreateAlloca(getIntType(4), nullptr);
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 0)), r);
// head
start_new_block(loophead);
llvm::Value *cond = builder->CreateICmpSLT(CreateLoad(r), rank);
builder->CreateCondBr(cond, loopbody, loopend);

// body
start_new_block(loopbody);
llvm::Value* r_val = CreateLoad(r);
llvm::Value* ret_val = CreateLoad(llvm_ret_ptr);
llvm::Value* dim_size = arr_descr->get_dimension_size(dim_des_val, r_val);
ret_val = builder->CreateMul(ret_val, dim_size);
builder->CreateStore(ret_val, llvm_ret_ptr);
r_val = builder->CreateAdd(r_val, llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
builder->CreateStore(r_val, r);
builder->CreateBr(loophead);

// end
start_new_block(loopend);

define_function_exit(x);
} else if( m_name == "lbound" || m_name == "ubound" ) {
if( m_name == "lbound" || m_name == "ubound" ) {
define_function_entry(x);

// Defines the size intrinsic's body at LLVM level.
Expand Down Expand Up @@ -2532,6 +2485,15 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}

void visit_AssociateBlockCall(const ASR::AssociateBlockCall_t& x) {
LFORTRAN_ASSERT(ASR::is_a<ASR::AssociateBlock_t>(*x.m_m));
ASR::AssociateBlock_t* associate_block = ASR::down_cast<ASR::AssociateBlock_t>(x.m_m);
declare_vars(*associate_block);
for (size_t i = 0; i < associate_block->n_body; i++) {
this->visit_stmt(*(associate_block->m_body[i]));
}
}

inline void visit_expr_wrapper(const ASR::expr_t* x, bool load_ref=false) {
this->visit_expr(*x);
if( x->type == ASR::exprType::ArrayRef ||
Expand Down Expand Up @@ -3844,31 +3806,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
x_abi = sub->m_abi;
}
if( x_abi == ASR::abiType::Intrinsic ) {
if( name == "size" ) {
/*
When size intrinsic is called on a fortran array then the above
code extracts the dimension descriptor array and its rank from the
overall array descriptor. It wraps them into a struct (specifically, arg_struct of type, size_arg here)
and passes to LLVM size. So, if you do, size(a) (a is a fortran array), then at LLVM level,
@size(%size_arg* %x) is used as call where size_arg
is described above.
*/
ASR::Variable_t *arg = EXPR2VAR(x.m_args[0].m_value);
uint32_t h = get_hash((ASR::asr_t*)arg);
tmp = llvm_symtab[h];
llvm::Value* arg_struct = builder->CreateAlloca(fname2arg_type["size"].first, nullptr);
llvm::Value* first_ele_ptr = arr_descr->get_pointer_to_dimension_descriptor_array(tmp);
llvm::Value* first_arg_ptr = llvm_utils->create_gep(arg_struct, 0);
builder->CreateStore(first_ele_ptr, first_arg_ptr);
llvm::Value* rank_ptr = llvm_utils->create_gep(arg_struct, 1);
builder->CreateStore(arr_descr->get_rank(tmp), rank_ptr);
tmp = arg_struct;
args.push_back(tmp);
llvm::Value* dim = builder->CreateAlloca(getIntType(4));
args.push_back(dim);
llvm::Value* kind = builder->CreateAlloca(getIntType(4));
args.push_back(kind);
} else if( name == "lbound" || name == "ubound" ) {
if( name == "lbound" || name == "ubound" ) {
ASR::Variable_t *arg = EXPR2VAR(x.m_args[0].m_value);
uint32_t h = get_hash((ASR::asr_t*)arg);
tmp = llvm_symtab[h];
Expand Down Expand Up @@ -4374,6 +4312,55 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
pop_nested_stack(s);
}

void visit_ArraySize(const ASR::ArraySize_t& x) {
if( x.m_value ) {
visit_expr_wrapper(x.m_value, true);
return ;
}
visit_expr_wrapper(x.m_v);
llvm::Value* llvm_arg = tmp;
llvm::Value* dim_des_val = arr_descr->get_pointer_to_dimension_descriptor_array(llvm_arg);
if( x.m_dim ) {
visit_expr_wrapper(x.m_dim, true);
int kind = ASRUtils::extract_kind_from_ttype_t(ASRUtils::expr_type(x.m_dim));
tmp = builder->CreateSub(tmp, llvm::ConstantInt::get(context, llvm::APInt(kind * 8, 1)));
tmp = arr_descr->get_dimension_size(dim_des_val, tmp);
return ;
}
llvm::Value* rank = arr_descr->get_rank(llvm_arg);
llvm::Value* llvm_size = builder->CreateAlloca(getIntType(ASRUtils::extract_kind_from_ttype_t(x.m_type)), nullptr);
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 1)), llvm_size);

llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
this->current_loophead = loophead;
this->current_loopend = loopend;

llvm::Value* r = builder->CreateAlloca(getIntType(4), nullptr);
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 0)), r);
// head
start_new_block(loophead);
llvm::Value *cond = builder->CreateICmpSLT(CreateLoad(r), rank);
builder->CreateCondBr(cond, loopbody, loopend);

// body
start_new_block(loopbody);
llvm::Value* r_val = CreateLoad(r);
llvm::Value* ret_val = CreateLoad(llvm_size);
llvm::Value* dim_size = arr_descr->get_dimension_size(dim_des_val, r_val);
ret_val = builder->CreateMul(ret_val, dim_size);
builder->CreateStore(ret_val, llvm_size);
r_val = builder->CreateAdd(r_val, llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
builder->CreateStore(r_val, r);
builder->CreateBr(loophead);

// end
start_new_block(loopend);

tmp = CreateLoad(llvm_size);
}

};


Expand Down
4 changes: 4 additions & 0 deletions src/libasr/pass/array_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,10 @@ class ArrayOpVisitor : public PassUtils::PassVisitor<ArrayOpVisitor>
visit_ArrayOpCommon<ASR::BoolOp_t>(x, "_bool_op_res");
}

void visit_ArraySize(const ASR::ArraySize_t& x) {
tmp_val = const_cast<ASR::expr_t*>(&(x.base));
}

void visit_FunctionCall(const ASR::FunctionCall_t& x) {
tmp_val = const_cast<ASR::expr_t*>(&(x.base));
std::string x_name;
Expand Down
7 changes: 0 additions & 7 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,13 +285,6 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
"'" + var_name + "' is undeclared");
throw SemanticAbort();
}
if( v->type == ASR::symbolType::Variable ) {
ASR::Variable_t* v_var = ASR::down_cast<ASR::Variable_t>(v);
if( v_var->m_type == nullptr &&
v_var->m_intent == ASR::intentType::AssociateBlock ) {
return (ASR::asr_t*)(v_var->m_symbolic_value);
}
}
return ASR::make_Var_t(al, loc, v);
}

Expand Down