Skip to content

Commit 657bfd8

Browse files
committed
Update ASR from LFortran
1 parent 7726b87 commit 657bfd8

File tree

6 files changed

+120
-74
lines changed

6 files changed

+120
-74
lines changed

src/libasr/ASR.asdl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,11 @@ symbol
105105
| ClassType(symbol_table symtab, identifier name, abi abi, access access)
106106
| ClassProcedure(symbol_table parent_symtab, identifier name, identifier
107107
proc_name, symbol proc, abi abi)
108+
| AssociateBlock(symbol_table symtab, identifier name, stmt* body)
108109

109110
storage_type = Default | Save | Parameter | Allocatable
110111
access = Public | Private
111-
intent = Local | In | Out | InOut | ReturnVar | Unspecified | AssociateBlock
112+
intent = Local | In | Out | InOut | ReturnVar | Unspecified
112113
deftype = Implementation | Interface
113114
presence = Required | Optional
114115

@@ -199,6 +200,7 @@ stmt
199200
| Nullify(symbol* vars)
200201
| Flush(int label, expr unit, expr? err, expr? iomsg, expr? iostat)
201202
| ListAppend(symbol a, expr ele)
203+
| AssociateBlockCall(symbol m)
202204
| SetInsert(symbol a, expr ele)
203205
| SetRemove(symbol a, expr ele)
204206
| ListInsert(symbol a, expr pos, expr ele)
@@ -245,6 +247,7 @@ expr
245247
| DictLen(expr arg, ttype type, expr? value)
246248
| Var(symbol v)
247249
| ArrayRef(symbol v, array_index* args, ttype type, expr? value)
250+
| ArraySize(expr v, expr? dim, ttype type, expr? value)
248251
| DerivedRef(expr v, symbol m, ttype type, expr? value)
249252
| Cast(expr arg, cast_kind kind, ttype type, expr? value)
250253
| ComplexRe(expr arg, ttype type, expr? value)

src/libasr/asr_utils.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ static inline ASR::ttype_t* expr_type(const ASR::expr_t *f)
123123
case ASR::exprType::IntegerBOZ: { return ((ASR::IntegerBOZ_t*)f)->m_type; }
124124
case ASR::exprType::Var: { return EXPR2VAR(f)->m_type; }
125125
case ASR::exprType::ArrayRef: { return ((ASR::ArrayRef_t*)f)->m_type; }
126+
case ASR::exprType::ArraySize: { return ((ASR::ArraySize_t*)f)->m_type; }
126127
case ASR::exprType::DerivedRef: { return ((ASR::DerivedRef_t*)f)->m_type; }
127128
case ASR::exprType::Cast: { return ((ASR::Cast_t*)f)->m_type; }
128129
case ASR::exprType::ComplexRe: { return ((ASR::ComplexRe_t*)f)->m_type; }
@@ -270,6 +271,7 @@ static inline ASR::expr_t* expr_value(ASR::expr_t *f)
270271
case ASR::exprType::Compare: { return ASR::down_cast<ASR::Compare_t>(f)->m_value; }
271272
case ASR::exprType::FunctionCall: { return ASR::down_cast<ASR::FunctionCall_t>(f)->m_value; }
272273
case ASR::exprType::ArrayRef: { return ASR::down_cast<ASR::ArrayRef_t>(f)->m_value; }
274+
case ASR::exprType::ArraySize: { return ASR::down_cast<ASR::ArraySize_t>(f)->m_value; }
273275
case ASR::exprType::DerivedRef: { return ASR::down_cast<ASR::DerivedRef_t>(f)->m_value; }
274276
case ASR::exprType::Cast: { return ASR::down_cast<ASR::Cast_t>(f)->m_value; }
275277
case ASR::exprType::Var: { return EXPR2VAR(f)->m_value; }
@@ -331,6 +333,9 @@ static inline char *symbol_name(const ASR::symbol_t *f)
331333
case ASR::symbolType::CustomOperator: {
332334
return ASR::down_cast<ASR::CustomOperator_t>(f)->m_name;
333335
}
336+
case ASR::symbolType::AssociateBlock: {
337+
return ASR::down_cast<ASR::AssociateBlock_t>(f)->m_name;
338+
}
334339
default : throw LFortranException("Not implemented");
335340
}
336341
}
@@ -368,6 +373,9 @@ static inline SymbolTable *symbol_parent_symtab(const ASR::symbol_t *f)
368373
case ASR::symbolType::CustomOperator: {
369374
return ASR::down_cast<ASR::CustomOperator_t>(f)->m_parent_symtab;
370375
}
376+
case ASR::symbolType::AssociateBlock: {
377+
return ASR::down_cast<ASR::AssociateBlock_t>(f)->m_symtab->parent;
378+
}
371379
default : throw LFortranException("Not implemented");
372380
}
373381
}
@@ -407,6 +415,9 @@ static inline SymbolTable *symbol_symtab(const ASR::symbol_t *f)
407415
return nullptr;
408416
//throw LFortranException("ClassProcedure does not have a symtab");
409417
}
418+
case ASR::symbolType::AssociateBlock: {
419+
return ASR::down_cast<ASR::AssociateBlock_t>(f)->m_symtab;
420+
}
410421
default : throw LFortranException("Not implemented");
411422
}
412423
}

src/libasr/asr_verify.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,29 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
151151
current_symtab = parent_symtab;
152152
}
153153

154+
void visit_AssociateBlock(const AssociateBlock_t& x) {
155+
SymbolTable *parent_symtab = current_symtab;
156+
current_symtab = x.m_symtab;
157+
require(x.m_symtab != nullptr,
158+
"The AssociateBlock::m_symtab cannot be nullptr");
159+
require(x.m_symtab->parent == parent_symtab,
160+
"The AssociateBlock::m_symtab->parent is not the right parent");
161+
require(id_symtab_map.find(x.m_symtab->counter) == id_symtab_map.end(),
162+
"AssociateBlock::m_symtab->counter must be unique");
163+
require(x.m_symtab->asr_owner == (ASR::asr_t*)&x,
164+
"The X::m_symtab::asr_owner must point to X");
165+
require(ASRUtils::symbol_symtab(down_cast<symbol_t>(current_symtab->asr_owner)) == current_symtab,
166+
"The asr_owner invariant failed");
167+
id_symtab_map[x.m_symtab->counter] = x.m_symtab;
168+
for (auto &a : x.m_symtab->scope) {
169+
this->visit_symbol(*a.second);
170+
}
171+
for (size_t i=0; i<x.n_body; i++) {
172+
visit_stmt(*x.m_body[i]);
173+
}
174+
current_symtab = parent_symtab;
175+
}
176+
154177
void visit_Module(const Module_t &x) {
155178
SymbolTable *parent_symtab = current_symtab;
156179
current_symtab = x.m_symtab;

src/libasr/codegen/asr_to_cpp.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,24 @@ Kokkos::View<T*> from_std_vector(const std::vector<T> &v)
523523
last_expr_precedence = 2;
524524
}
525525

526+
void visit_ArraySize(const ASR::ArraySize_t& x) {
527+
visit_expr(*x.m_v);
528+
std::string var_name = src;
529+
std::string args = "";
530+
if (x.m_dim == nullptr) {
531+
// TODO: return the product of all dimensions:
532+
args = "0";
533+
} else {
534+
if( x.m_dim ) {
535+
visit_expr(*x.m_dim);
536+
args += src + "-1";
537+
args += ", ";
538+
}
539+
args += std::to_string(ASRUtils::extract_kind_from_ttype_t(x.m_type)) + "-1";
540+
}
541+
src = var_name + ".extent(" + args + ")";
542+
}
543+
526544
void visit_Assignment(const ASR::Assignment_t &x) {
527545
std::string target;
528546
if (ASR::is_a<ASR::Var_t>(*x.m_target)) {

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 60 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -880,10 +880,6 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
880880
complex_type_8_ptr = llvm::StructType::create(context, els_8_ptr, "complex_8_ptr");
881881
character_type = llvm::Type::getInt8PtrTy(context);
882882

883-
llvm::Type* size_arg = (llvm::Type*)llvm::StructType::create(context, std::vector<llvm::Type*>({
884-
arr_descr->get_dimension_descriptor_type(true),
885-
getIntType(4)}), "size_arg");
886-
fname2arg_type["size"] = std::make_pair(size_arg, size_arg->getPointerTo());
887883
llvm::Type* bound_arg = static_cast<llvm::Type*>(arr_descr->get_dimension_descriptor_type(true));
888884
fname2arg_type["lbound"] = std::make_pair(bound_arg, bound_arg->getPointerTo());
889885
fname2arg_type["ubound"] = std::make_pair(bound_arg, bound_arg->getPointerTo());
@@ -2356,50 +2352,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
23562352
} else if( x.m_abi == ASR::abiType::Intrinsic &&
23572353
x.m_deftype == ASR::deftypeType::Interface ) {
23582354
std::string m_name = x.m_name;
2359-
if( m_name == "size" ) {
2360-
2361-
define_function_entry(x);
2362-
2363-
// Defines the size intrinsic's body at LLVM level.
2364-
ASR::Variable_t *arg = EXPR2VAR(x.m_args[0]);
2365-
uint32_t h = get_hash((ASR::asr_t*)arg);
2366-
llvm::Value* llvm_arg = llvm_symtab[h];
2367-
ASR::Variable_t *ret = EXPR2VAR(x.m_return_var);
2368-
h = get_hash((ASR::asr_t*)ret);
2369-
llvm::Value* llvm_ret_ptr = llvm_symtab[h];
2370-
llvm::Value* dim_des_val = CreateLoad(llvm_utils->create_gep(llvm_arg, 0));
2371-
llvm::Value* rank = CreateLoad(llvm_utils->create_gep(llvm_arg, 1));
2372-
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 1)), llvm_ret_ptr);
2373-
2374-
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
2375-
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
2376-
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
2377-
this->current_loophead = loophead;
2378-
this->current_loopend = loopend;
2379-
2380-
llvm::Value* r = builder->CreateAlloca(getIntType(4), nullptr);
2381-
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 0)), r);
2382-
// head
2383-
start_new_block(loophead);
2384-
llvm::Value *cond = builder->CreateICmpSLT(CreateLoad(r), rank);
2385-
builder->CreateCondBr(cond, loopbody, loopend);
2386-
2387-
// body
2388-
start_new_block(loopbody);
2389-
llvm::Value* r_val = CreateLoad(r);
2390-
llvm::Value* ret_val = CreateLoad(llvm_ret_ptr);
2391-
llvm::Value* dim_size = arr_descr->get_dimension_size(dim_des_val, r_val);
2392-
ret_val = builder->CreateMul(ret_val, dim_size);
2393-
builder->CreateStore(ret_val, llvm_ret_ptr);
2394-
r_val = builder->CreateAdd(r_val, llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
2395-
builder->CreateStore(r_val, r);
2396-
builder->CreateBr(loophead);
2397-
2398-
// end
2399-
start_new_block(loopend);
2400-
2401-
define_function_exit(x);
2402-
} else if( m_name == "lbound" || m_name == "ubound" ) {
2355+
if( m_name == "lbound" || m_name == "ubound" ) {
24032356
define_function_entry(x);
24042357

24052358
// Defines the size intrinsic's body at LLVM level.
@@ -2532,6 +2485,15 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
25322485
}
25332486
}
25342487

2488+
void visit_AssociateBlockCall(const ASR::AssociateBlockCall_t& x) {
2489+
LFORTRAN_ASSERT(ASR::is_a<ASR::AssociateBlock_t>(*x.m_m));
2490+
ASR::AssociateBlock_t* associate_block = ASR::down_cast<ASR::AssociateBlock_t>(x.m_m);
2491+
declare_vars(*associate_block);
2492+
for (size_t i = 0; i < associate_block->n_body; i++) {
2493+
this->visit_stmt(*(associate_block->m_body[i]));
2494+
}
2495+
}
2496+
25352497
inline void visit_expr_wrapper(const ASR::expr_t* x, bool load_ref=false) {
25362498
this->visit_expr(*x);
25372499
if( x->type == ASR::exprType::ArrayRef ||
@@ -3844,31 +3806,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
38443806
x_abi = sub->m_abi;
38453807
}
38463808
if( x_abi == ASR::abiType::Intrinsic ) {
3847-
if( name == "size" ) {
3848-
/*
3849-
When size intrinsic is called on a fortran array then the above
3850-
code extracts the dimension descriptor array and its rank from the
3851-
overall array descriptor. It wraps them into a struct (specifically, arg_struct of type, size_arg here)
3852-
and passes to LLVM size. So, if you do, size(a) (a is a fortran array), then at LLVM level,
3853-
@size(%size_arg* %x) is used as call where size_arg
3854-
is described above.
3855-
*/
3856-
ASR::Variable_t *arg = EXPR2VAR(x.m_args[0].m_value);
3857-
uint32_t h = get_hash((ASR::asr_t*)arg);
3858-
tmp = llvm_symtab[h];
3859-
llvm::Value* arg_struct = builder->CreateAlloca(fname2arg_type["size"].first, nullptr);
3860-
llvm::Value* first_ele_ptr = arr_descr->get_pointer_to_dimension_descriptor_array(tmp);
3861-
llvm::Value* first_arg_ptr = llvm_utils->create_gep(arg_struct, 0);
3862-
builder->CreateStore(first_ele_ptr, first_arg_ptr);
3863-
llvm::Value* rank_ptr = llvm_utils->create_gep(arg_struct, 1);
3864-
builder->CreateStore(arr_descr->get_rank(tmp), rank_ptr);
3865-
tmp = arg_struct;
3866-
args.push_back(tmp);
3867-
llvm::Value* dim = builder->CreateAlloca(getIntType(4));
3868-
args.push_back(dim);
3869-
llvm::Value* kind = builder->CreateAlloca(getIntType(4));
3870-
args.push_back(kind);
3871-
} else if( name == "lbound" || name == "ubound" ) {
3809+
if( name == "lbound" || name == "ubound" ) {
38723810
ASR::Variable_t *arg = EXPR2VAR(x.m_args[0].m_value);
38733811
uint32_t h = get_hash((ASR::asr_t*)arg);
38743812
tmp = llvm_symtab[h];
@@ -4374,6 +4312,55 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
43744312
pop_nested_stack(s);
43754313
}
43764314

4315+
void visit_ArraySize(const ASR::ArraySize_t& x) {
4316+
if( x.m_value ) {
4317+
visit_expr_wrapper(x.m_value, true);
4318+
return ;
4319+
}
4320+
visit_expr_wrapper(x.m_v);
4321+
llvm::Value* llvm_arg = tmp;
4322+
llvm::Value* dim_des_val = arr_descr->get_pointer_to_dimension_descriptor_array(llvm_arg);
4323+
if( x.m_dim ) {
4324+
visit_expr_wrapper(x.m_dim, true);
4325+
int kind = ASRUtils::extract_kind_from_ttype_t(ASRUtils::expr_type(x.m_dim));
4326+
tmp = builder->CreateSub(tmp, llvm::ConstantInt::get(context, llvm::APInt(kind * 8, 1)));
4327+
tmp = arr_descr->get_dimension_size(dim_des_val, tmp);
4328+
return ;
4329+
}
4330+
llvm::Value* rank = arr_descr->get_rank(llvm_arg);
4331+
llvm::Value* llvm_size = builder->CreateAlloca(getIntType(ASRUtils::extract_kind_from_ttype_t(x.m_type)), nullptr);
4332+
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 1)), llvm_size);
4333+
4334+
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
4335+
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
4336+
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
4337+
this->current_loophead = loophead;
4338+
this->current_loopend = loopend;
4339+
4340+
llvm::Value* r = builder->CreateAlloca(getIntType(4), nullptr);
4341+
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 0)), r);
4342+
// head
4343+
start_new_block(loophead);
4344+
llvm::Value *cond = builder->CreateICmpSLT(CreateLoad(r), rank);
4345+
builder->CreateCondBr(cond, loopbody, loopend);
4346+
4347+
// body
4348+
start_new_block(loopbody);
4349+
llvm::Value* r_val = CreateLoad(r);
4350+
llvm::Value* ret_val = CreateLoad(llvm_size);
4351+
llvm::Value* dim_size = arr_descr->get_dimension_size(dim_des_val, r_val);
4352+
ret_val = builder->CreateMul(ret_val, dim_size);
4353+
builder->CreateStore(ret_val, llvm_size);
4354+
r_val = builder->CreateAdd(r_val, llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
4355+
builder->CreateStore(r_val, r);
4356+
builder->CreateBr(loophead);
4357+
4358+
// end
4359+
start_new_block(loopend);
4360+
4361+
tmp = CreateLoad(llvm_size);
4362+
}
4363+
43774364
};
43784365

43794366

src/libasr/pass/array_op.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,10 @@ class ArrayOpVisitor : public PassUtils::PassVisitor<ArrayOpVisitor>
662662
visit_ArrayOpCommon<ASR::BoolOp_t>(x, "_bool_op_res");
663663
}
664664

665+
void visit_ArraySize(const ASR::ArraySize_t& x) {
666+
tmp_val = const_cast<ASR::expr_t*>(&(x.base));
667+
}
668+
665669
void visit_FunctionCall(const ASR::FunctionCall_t& x) {
666670
tmp_val = const_cast<ASR::expr_t*>(&(x.base));
667671
std::string x_name;

0 commit comments

Comments
 (0)