Skip to content

Commit 9aefff4

Browse files
authored
Merge pull request #2118 from czgdp1807/array
Accept ``dtype`` argument in ``numpy.array``
2 parents c3314f7 + 4c80fc6 commit 9aefff4

File tree

3 files changed

+68
-3
lines changed

3 files changed

+68
-3
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ RUN(NAME variable_decl_02 LABELS cpython llvm c)
410410
RUN(NAME variable_decl_03 LABELS cpython llvm c)
411411
RUN(NAME array_expr_01 LABELS cpython llvm c)
412412
RUN(NAME array_expr_02 LABELS cpython llvm c NOFAST)
413+
RUN(NAME array_expr_03 LABELS cpython llvm c)
413414
RUN(NAME array_size_01 LABELS cpython llvm c)
414415
RUN(NAME array_size_02 LABELS cpython llvm c)
415416
RUN(NAME array_01 LABELS cpython llvm wasm c)

integration_tests/array_expr_03.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from lpython import i8, i32, dataclass
2+
from numpy import empty, int8, array
3+
4+
5+
@dataclass
6+
class LPBHV_small:
7+
dim: i32 = 4
8+
a: i8[4] = empty(4, dtype=int8)
9+
10+
11+
def g():
12+
l2: LPBHV_small = LPBHV_small(4, array([127, -127, 3, 111], dtype=int8))
13+
14+
print(l2.dim)
15+
assert l2.dim == 4
16+
17+
print(l2.a[0], l2.a[1], l2.a[2], l2.a[3])
18+
assert l2.a[0] == i8(127)
19+
assert l2.a[1] == i8(-127)
20+
assert l2.a[2] == i8(3)
21+
assert l2.a[3] == i8(111)
22+
23+
24+
g()

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ namespace CastingUtil {
164164
}
165165
cast_kind = type_rules.at(cast_key);
166166
}
167+
if( ASRUtils::check_equal_type(src, dest, true) ) {
168+
return expr;
169+
}
167170
// TODO: Fix loc
168171
return ASRUtils::EXPR(ASRUtils::make_Cast_t_value(al, loc, expr,
169172
cast_kind, dest));
@@ -505,6 +508,10 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
505508
// Stores the name of imported functions and the modules they are imported from
506509
std::map<std::string, std::string> imported_functions;
507510

511+
std::map<std::string, std::string> numpy2lpythontypes = {
512+
{"int8", "i8"},
513+
};
514+
508515
CommonVisitor(Allocator &al, LocationManager &lm, SymbolTable *symbol_table,
509516
diag::Diagnostics &diagnostics, bool main_module, std::string module_name,
510517
std::map<int, ASR::symbol_t*> &ast_overload, std::string parent_dir,
@@ -7520,16 +7527,45 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
75207527
tmp = ASR::make_UnsignedIntegerBitNot_t(al, x.base.base.loc, operand, operand_type, value);
75217528
return;
75227529
} else if( call_name == "array" ) {
7523-
parse_args(x, args);
7530+
ASR::ttype_t* type = nullptr;
7531+
if( x.n_keywords == 0 ) {
7532+
parse_args(x, args);
7533+
} else {
7534+
args.reserve(al, 1);
7535+
visit_expr_list(x.m_args, x.n_args, args);
7536+
if( x.n_keywords > 1 ) {
7537+
throw SemanticError("More than one keyword "
7538+
"arguments aren't recognised by array",
7539+
x.base.base.loc);
7540+
}
7541+
if( std::string(x.m_keywords[0].m_arg) != "dtype" ) {
7542+
throw SemanticError("Unrecognised keyword argument, " +
7543+
std::string(x.m_keywords[0].m_arg), x.base.base.loc);
7544+
}
7545+
std::string dtype_np = "";
7546+
if( AST::is_a<AST::Name_t>(*x.m_keywords[0].m_value) ) {
7547+
AST::Name_t* name_t = AST::down_cast<AST::Name_t>(x.m_keywords[0].m_value);
7548+
dtype_np = name_t->m_id;
7549+
} else {
7550+
LCOMPILERS_ASSERT(false);
7551+
}
7552+
LCOMPILERS_ASSERT(numpy2lpythontypes.find(dtype_np) != numpy2lpythontypes.end());
7553+
Vec<ASR::dimension_t> dims;
7554+
dims.n = 0;
7555+
type = get_type_from_var_annotation(
7556+
numpy2lpythontypes[dtype_np], x.base.base.loc, dims);
7557+
}
75247558
if( args.size() != 1 ) {
75257559
throw SemanticError("array accepts only 1 argument for now, got " +
75267560
std::to_string(args.size()) + " arguments instead.",
75277561
x.base.base.loc);
75287562
}
75297563
ASR::expr_t *arg = args[0].m_value;
7530-
ASR::ttype_t *type = ASRUtils::expr_type(arg);
7564+
if( type == nullptr ) {
7565+
type = ASRUtils::expr_type(arg);
7566+
}
75317567
if(ASR::is_a<ASR::ListConstant_t>(*arg)) {
7532-
type = ASR::down_cast<ASR::List_t>(type)->m_type;
7568+
type = ASRUtils::get_contained_type(type);
75337569
ASR::ListConstant_t* list = ASR::down_cast<ASR::ListConstant_t>(arg);
75347570
ASR::expr_t **m_args = list->m_args;
75357571
size_t n_args = list->n_args;
@@ -7544,6 +7580,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
75447580
dims.push_back(al, dim);
75457581
type = ASRUtils::make_Array_t_util(al, x.base.base.loc, type, dims.p, dims.size(),
75467582
ASR::abiType::Source, false, ASR::array_physical_typeType::PointerToDataArray, true);
7583+
for( size_t i = 0; i < n_args; i++ ) {
7584+
m_args[i] = CastingUtil::perform_casting(m_args[i], ASRUtils::expr_type(m_args[i]),
7585+
ASRUtils::type_get_past_array(type), al, x.base.base.loc);
7586+
}
75477587
tmp = ASR::make_ArrayConstant_t(al, x.base.base.loc, m_args, n_args, type, ASR::arraystorageType::RowMajor);
75487588
} else {
75497589
throw SemanticError("array accepts only list for now, got " +

0 commit comments

Comments
 (0)