Skip to content

Commit c459bb5

Browse files
authored
Merge pull request #2012 from Shaikh-Ubaid/pythoncall_arrays_as_return_type
PythonCall: Support array of simple types as return type
2 parents 62e5ead + 832245c commit c459bb5

File tree

8 files changed

+272
-32
lines changed

8 files changed

+272
-32
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ RUN(NAME bindc_06 LABELS llvm c
533533
EXTRAFILES bindc_06b.c)
534534
RUN(NAME bindpy_01 LABELS cpython c_py ENABLE_CPYTHON NOFAST EXTRAFILES bindpy_01_module.py)
535535
RUN(NAME bindpy_02 LABELS cpython c_py LINK_NUMPY EXTRAFILES bindpy_02_module.py)
536+
RUN(NAME bindpy_03 LABELS cpython c_py LINK_NUMPY NOFAST EXTRAFILES bindpy_03_module.py)
536537
RUN(NAME test_generics_01 LABELS cpython llvm c NOFAST)
537538
RUN(NAME test_cmath LABELS cpython llvm c NOFAST)
538539
RUN(NAME test_complex_01 LABELS cpython llvm c wasm wasm_x64)

integration_tests/bindpy_03.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from lpython import i32, i64, f64, pythoncall, Const, TypeVar
2+
from numpy import empty, int32, int64, float64
3+
4+
n = TypeVar("n")
5+
m = TypeVar("m")
6+
p = TypeVar("p")
7+
q = TypeVar("q")
8+
r = TypeVar("r")
9+
10+
@pythoncall(module = "bindpy_03_module")
11+
def get_cpython_version() -> str:
12+
pass
13+
14+
@pythoncall(module = "bindpy_03_module")
15+
def get_int_array_sum(n: i32, a: i32[:], b: i32[:]) -> i32[n]:
16+
pass
17+
18+
@pythoncall(module = "bindpy_03_module")
19+
def get_int_array_product(n: i32, a: i32[:], b: i32[:]) -> i32[n]:
20+
pass
21+
22+
@pythoncall(module = "bindpy_03_module")
23+
def get_float_array_sum(n: i32, m: i32, a: f64[:], b: f64[:]) -> f64[n, m]:
24+
pass
25+
26+
@pythoncall(module = "bindpy_03_module")
27+
def get_float_array_product(n: i32, m: i32, a: f64[:], b: f64[:]) -> f64[n, m]:
28+
pass
29+
30+
@pythoncall(module = "bindpy_03_module")
31+
def get_array_dot_product(m: i32, a: i64[:], b: f64[:]) -> f64[m]:
32+
pass
33+
34+
@pythoncall(module = "bindpy_03_module")
35+
def get_multidim_array_i64(p: i32, q: i32, r: i32) -> i64[p, q, r]:
36+
pass
37+
38+
# Integers:
39+
def test_array_ints():
40+
n: Const[i32] = 5
41+
a: i32[n] = empty([n], dtype=int32)
42+
b: i32[n] = empty([n], dtype=int32)
43+
44+
i: i32
45+
for i in range(n):
46+
a[i] = i + 10
47+
for i in range(n):
48+
b[i] = i + 20
49+
50+
c: i32[n] = get_int_array_sum(n, a, b)
51+
print(c)
52+
for i in range(n):
53+
assert c[i] == (i + i + 30)
54+
55+
56+
c = get_int_array_product(n, a, b)
57+
print(c)
58+
for i in range(n):
59+
assert c[i] == ((i + 10) * (i + 20))
60+
61+
# Floats
62+
def test_array_floats():
63+
n: Const[i32] = 3
64+
m: Const[i32] = 5
65+
a: f64[n, m] = empty([n, m], dtype=float64)
66+
b: f64[n, m] = empty([n, m], dtype=float64)
67+
68+
i: i32
69+
j: i32
70+
71+
for i in range(n):
72+
for j in range(m):
73+
a[i, j] = f64((i + 10) * (j + 10))
74+
75+
for i in range(n):
76+
for j in range(m):
77+
b[i, j] = f64((i + 20) * (j + 20))
78+
79+
c: f64[n, m] = get_float_array_sum(n, m, a, b)
80+
print(c)
81+
for i in range(n):
82+
for j in range(m):
83+
assert abs(c[i, j] - (f64((i + 10) * (j + 10)) + f64((i + 20) * (j + 20)))) <= 1e-4
84+
85+
c = get_float_array_product(n, m, a, b)
86+
print(c)
87+
for i in range(n):
88+
for j in range(m):
89+
assert abs(c[i, j] - (f64((i + 10) * (j + 10)) * f64((i + 20) * (j + 20)))) <= 1e-4
90+
91+
def test_array_broadcast():
92+
n: Const[i32] = 3
93+
m: Const[i32] = 5
94+
a: i64[n] = empty([n], dtype=int64)
95+
b: f64[n, m] = empty([n, m], dtype=float64)
96+
97+
i: i32
98+
j: i32
99+
for i in range(n):
100+
a[i] = i64(i + 10)
101+
102+
for i in range(n):
103+
for j in range(m):
104+
b[i, j] = f64((i + 1) * (j + 1))
105+
106+
c: f64[m] = get_array_dot_product(m, a, b)
107+
print(c)
108+
assert abs(c[0] - (68.0)) <= 1e-4
109+
assert abs(c[1] - (136.0)) <= 1e-4
110+
assert abs(c[2] - (204.0)) <= 1e-4
111+
assert abs(c[3] - (272.0)) <= 1e-4
112+
assert abs(c[4] - (340.0)) <= 1e-4
113+
114+
def test_multidim_array_return_i64():
115+
p: Const[i32] = 3
116+
q: Const[i32] = 4
117+
r: Const[i32] = 5
118+
a: i64[p, q, r] = empty([p, q, r], dtype=int64)
119+
a = get_multidim_array_i64(p, q, r)
120+
print(a)
121+
122+
i: i32; j: i32; k: i32
123+
for i in range(p):
124+
for j in range(q):
125+
for k in range(r):
126+
assert a[i, j, k] == i64(i * 2 + j * 3 + k * 4)
127+
128+
def main0():
129+
print("CPython version: ", get_cpython_version())
130+
131+
test_array_ints()
132+
test_array_floats()
133+
test_array_broadcast()
134+
test_multidim_array_return_i64()
135+
136+
main0()

integration_tests/bindpy_03_module.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import numpy as np
2+
3+
def get_cpython_version():
4+
import platform
5+
return platform.python_version()
6+
7+
def get_int_array_sum(n, a, b):
8+
return np.add(a, b)
9+
10+
def get_int_array_product(n, a, b):
11+
return np.multiply(a, b)
12+
13+
def get_float_array_sum(n, m, a, b):
14+
return np.add(a, b)
15+
16+
def get_float_array_product(n, m, a, b):
17+
return np.multiply(a, b)
18+
19+
def get_array_dot_product(m, a, b):
20+
print(a, b)
21+
c = a @ b
22+
print(c)
23+
return c
24+
25+
def get_multidim_array_i64(p, q, r):
26+
a = np.empty([p, q, r], dtype = np.int64)
27+
for i in range(p):
28+
for j in range(q):
29+
for k in range(r):
30+
a[i, j, k] = i * 2 + j * 3 + k * 4
31+
return a

src/libasr/asr_utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,9 @@ static inline ASR::symbol_t *get_asr_owner(const ASR::expr_t *expr) {
612612
case ASR::exprType::GetPointer: {
613613
return ASRUtils::get_asr_owner(ASR::down_cast<ASR::GetPointer_t>(expr)->m_arg);
614614
}
615+
case ASR::exprType::FunctionCall: {
616+
return ASRUtils::get_asr_owner(ASR::down_cast<ASR::FunctionCall_t>(expr)->m_name);
617+
}
615618
default: {
616619
throw LCompilersException("Cannot find the ASR owner of underlying symbol of expression "
617620
+ std::to_string(expr->type));

src/libasr/codegen/asr_to_c.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,13 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
3030
{
3131
public:
3232

33-
std::string array_types_decls;
34-
3533
std::unique_ptr<CUtils::CUtilFunctions> c_utils_functions;
3634

3735
int counter;
3836

3937
ASRToCVisitor(diag::Diagnostics &diag, CompilerOptions &co,
4038
int64_t default_lower_bound)
4139
: BaseCCPPVisitor(diag, co.platform, co, false, false, true, default_lower_bound),
42-
array_types_decls(std::string("")),
4340
c_utils_functions{std::make_unique<CUtils::CUtilFunctions>()},
4441
counter{0} {
4542
}

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
105105
std::map<uint64_t, std::string> const_var_names;
106106
std::map<int32_t, std::string> gotoid2name;
107107
std::map<std::string, std::string> emit_headers;
108+
std::string array_types_decls;
108109

109110
// Output configuration:
110111
// Use std::string or char*
@@ -146,7 +147,7 @@ class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
146147
BaseCCPPVisitor(diag::Diagnostics &diag, Platform &platform,
147148
CompilerOptions &_compiler_options, bool gen_stdstring, bool gen_stdcomplex, bool is_c,
148149
int64_t default_lower_bound) : diag{diag},
149-
platform{platform}, compiler_options{_compiler_options},
150+
platform{platform}, compiler_options{_compiler_options}, array_types_decls{std::string("")},
150151
gen_stdstring{gen_stdstring}, gen_stdcomplex{gen_stdcomplex},
151152
is_c{is_c}, global_scope{nullptr}, lower_bound{default_lower_bound},
152153
template_number{0}, c_ds_api{std::make_unique<CCPPDSUtils>(is_c, platform)},
@@ -381,28 +382,32 @@ R"(#include <stdio.h>
381382
}
382383
if (x.m_return_var) {
383384
ASR::Variable_t *return_var = ASRUtils::EXPR2VAR(x.m_return_var);
385+
bool is_array = ASRUtils::is_array(return_var->m_type);
384386
if (ASRUtils::is_integer(*return_var->m_type)) {
385-
int kind = ASR::down_cast<ASR::Integer_t>(return_var->m_type)->m_kind;
386-
switch (kind) {
387-
case (1) : sub = "int8_t "; break;
388-
case (2) : sub = "int16_t "; break;
389-
case (4) : sub = "int32_t "; break;
390-
case (8) : sub = "int64_t "; break;
387+
int kind = ASRUtils::extract_kind_from_ttype_t(return_var->m_type);
388+
if (is_array) {
389+
sub = "struct i" + std::to_string(kind * 8) + "* ";
390+
} else {
391+
sub = "int" + std::to_string(kind * 8) + "_t ";
391392
}
392393
} else if (ASRUtils::is_unsigned_integer(*return_var->m_type)) {
393-
int kind = ASR::down_cast<ASR::UnsignedInteger_t>(return_var->m_type)->m_kind;
394-
switch (kind) {
395-
case (1) : sub = "uint8_t "; break;
396-
case (2) : sub = "uint16_t "; break;
397-
case (4) : sub = "uint32_t "; break;
398-
case (8) : sub = "uint64_t "; break;
394+
int kind = ASRUtils::extract_kind_from_ttype_t(return_var->m_type);
395+
if (is_array) {
396+
sub = "struct u" + std::to_string(kind * 8) + "* ";
397+
} else {
398+
sub = "uint" + std::to_string(kind * 8) + "_t ";
399399
}
400400
} else if (ASRUtils::is_real(*return_var->m_type)) {
401-
bool is_float = ASR::down_cast<ASR::Real_t>(return_var->m_type)->m_kind == 4;
402-
if (is_float) {
403-
sub = "float ";
401+
int kind = ASRUtils::extract_kind_from_ttype_t(return_var->m_type);
402+
bool is_float = (kind == 4);
403+
if (is_array) {
404+
sub = "struct r" + std::to_string(kind * 8) + "* ";
404405
} else {
405-
sub = "double ";
406+
if (is_float) {
407+
sub = "float ";
408+
} else {
409+
sub = "double ";
410+
}
406411
}
407412
} else if (ASRUtils::is_logical(*return_var->m_type)) {
408413
sub = "bool ";
@@ -534,17 +539,30 @@ R"(#include <stdio.h>
534539
if (!x.m_return_var) return "";
535540
ASR::Variable_t* r_v = ASRUtils::EXPR2VAR(x.m_return_var);
536541
std::string indent = "\n ";
537-
std::string py_val_cnvrt = CUtils::get_py_obj_return_type_conv_func_from_ttype_t(r_v->m_type) + "(pValue)";
538-
std::string ret_var_decl = indent + CUtils::get_c_type_from_ttype_t(r_v->m_type) + " " + std::string(r_v->m_name) + ";";
539-
std::string ret_assign = indent + std::string(r_v->m_name) + " = " + py_val_cnvrt + ";";
540-
std::string ret_stmt = indent + "return " + std::string(r_v->m_name) + ";";
541-
std::string clear_pValue = indent + "Py_DECREF(pValue);";
542-
std::string copy_result = "";
542+
std::string py_val_cnvrt, ret_var_decl, copy_result;
543543
if (ASRUtils::is_aggregate_type(r_v->m_type)) {
544-
if (ASRUtils::is_character(*r_v->m_type)) {
545-
copy_result = indent + std::string(r_v->m_name) + " = _lfortran_str_copy(" + std::string(r_v->m_name) + ", 1, 0);";
544+
if (ASRUtils::is_array(r_v->m_type)) {
545+
ASR::ttype_t* array_type_asr = ASRUtils::type_get_past_array(r_v->m_type);
546+
std::string array_type_name = CUtils::get_c_type_from_ttype_t(array_type_asr);
547+
std::string array_encoded_type_name = ASRUtils::get_type_code(array_type_asr, true, false);
548+
std::string return_type = c_ds_api->get_array_type(array_type_name, array_encoded_type_name, array_types_decls, true);
549+
py_val_cnvrt = bind_py_utils_functions->get_conv_py_arr_to_c(return_type, array_type_name,
550+
array_encoded_type_name) + "(pValue)";
551+
ret_var_decl = indent + return_type + " _lpython_return_variable;";
552+
} else {
553+
if (ASRUtils::is_character(*r_v->m_type)) {
554+
py_val_cnvrt = CUtils::get_py_obj_return_type_conv_func_from_ttype_t(r_v->m_type) + "(pValue)";
555+
ret_var_decl = indent + CUtils::get_c_type_from_ttype_t(r_v->m_type) + " _lpython_return_variable;";
556+
copy_result = indent + "_lpython_return_variable = _lfortran_str_copy(" + std::string(r_v->m_name) + ", 1, 0);";
557+
}
546558
}
559+
} else {
560+
py_val_cnvrt = CUtils::get_py_obj_return_type_conv_func_from_ttype_t(r_v->m_type) + "(pValue)";
561+
ret_var_decl = indent + CUtils::get_c_type_from_ttype_t(r_v->m_type) + " _lpython_return_variable;";
547562
}
563+
std::string ret_assign = indent + std::string(r_v->m_name) + " = " + py_val_cnvrt + ";";
564+
std::string ret_stmt = indent + "return _lpython_return_variable;";
565+
std::string clear_pValue = indent + "Py_DECREF(pValue);";
548566
return ret_var_decl + ret_assign + copy_result + clear_pValue + ret_stmt + "\n";
549567
}
550568

0 commit comments

Comments
 (0)