Skip to content

Commit bc7c483

Browse files
committed
ASR: Support named arguments in struct initialization
1 parent 8c287b4 commit bc7c483

File tree

1 file changed

+68
-1
lines changed

1 file changed

+68
-1
lines changed

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,63 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
893893
return true;
894894
}
895895

896+
int64_t find_argument_position_from_name(ASR::StructType_t* orig_struct, std::string arg_name) {
897+
int64_t arg_position = -1;
898+
for( size_t i = 0; i < orig_struct->n_members; i++ ) {
899+
std::string original_arg_name = std::string(orig_struct->m_members[i]);
900+
if( original_arg_name == arg_name ) {
901+
return i;
902+
}
903+
}
904+
return arg_position;
905+
}
906+
907+
void visit_expr_list(AST::expr_t** pos_args, size_t n_pos_args,
908+
AST::keyword_t* kwargs, size_t n_kwargs,
909+
Vec<ASR::call_arg_t>& call_args_vec,
910+
ASR::StructType_t* orig_struct, const Location &loc) {
911+
LCOMPILERS_ASSERT(call_args_vec.reserve_called);
912+
913+
// Fill the whole call_args_vec with nullptr
914+
// This is for error handling later on.
915+
for( size_t i = 0; i < n_pos_args + n_kwargs; i++ ) {
916+
ASR::call_arg_t call_arg;
917+
Location loc;
918+
loc.first = loc.last = 1;
919+
call_arg.m_value = nullptr;
920+
call_arg.loc = loc;
921+
call_args_vec.push_back(al, call_arg);
922+
}
923+
924+
// Now handle positional arguments in the following loop
925+
for( size_t i = 0; i < n_pos_args; i++ ) {
926+
this->visit_expr(*pos_args[i]);
927+
ASR::expr_t* expr = ASRUtils::EXPR(tmp);
928+
call_args_vec.p[i].loc = expr->base.loc;
929+
call_args_vec.p[i].m_value = expr;
930+
}
931+
932+
// Now handle keyword arguments in the following loop
933+
for( size_t i = 0; i < n_kwargs; i++ ) {
934+
this->visit_expr(*kwargs[i].m_value);
935+
ASR::expr_t* expr = ASRUtils::EXPR(tmp);
936+
std::string arg_name = std::string(kwargs[i].m_arg);
937+
int64_t arg_pos = find_argument_position_from_name(orig_struct, arg_name);
938+
if( arg_pos == -1 ) {
939+
throw SemanticError("Member '" + arg_name + "' not found in struct", kwargs[i].loc);
940+
} else if (arg_pos >= (int64_t)call_args_vec.size()) {
941+
throw SemanticError("Not enough arguments to " + std::string(orig_struct->m_name)
942+
+ "(), expected " + std::to_string(orig_struct->n_members), loc);
943+
}
944+
if( call_args_vec[arg_pos].m_value != nullptr ) {
945+
throw SemanticError(std::string(orig_struct->m_name) + "() got multiple values for argument '"
946+
+ arg_name + "'", kwargs[i].loc);
947+
}
948+
call_args_vec.p[arg_pos].loc = expr->base.loc;
949+
call_args_vec.p[arg_pos].m_value = expr;
950+
}
951+
}
952+
896953
void visit_expr_list_with_cast(ASR::expr_t** m_args, size_t n_args,
897954
Vec<ASR::call_arg_t>& call_args_vec,
898955
Vec<ASR::call_arg_t>& args,
@@ -1195,7 +1252,17 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
11951252
}
11961253
} else if(ASR::is_a<ASR::StructType_t>(*s)) {
11971254
ASR::StructType_t* StructType = ASR::down_cast<ASR::StructType_t>(s);
1198-
for( size_t i = 0; i < std::min(args.size(), StructType->n_members); i++ ) {
1255+
if (n_kwargs > 0) {
1256+
args.reserve(al, n_pos_args + n_kwargs);
1257+
visit_expr_list(pos_args, n_pos_args, kwargs, n_kwargs,
1258+
args, StructType, loc);
1259+
}
1260+
1261+
if (args.size() > 0 && args.size() != StructType->n_members) {
1262+
throw SemanticError("StructConstructor arguments do not match the number of struct members", loc);
1263+
}
1264+
1265+
for( size_t i = 0; i < args.size(); i++ ) {
11991266
std::string member_name = StructType->m_members[i];
12001267
ASR::Variable_t* member_var = ASR::down_cast<ASR::Variable_t>(
12011268
StructType->m_symtab->resolve_symbol(member_name));

0 commit comments

Comments
 (0)