Skip to content

Commit 144a50d

Browse files
authored
Merge pull request #1780 from Shaikh-Ubaid/numpy_size_test_case
Fix numpy size() and add test
2 parents 8c71d37 + 5329920 commit 144a50d

File tree

3 files changed

+97
-3
lines changed

3 files changed

+97
-3
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ RUN(NAME variable_decl_03 LABELS cpython llvm c)
281281
RUN(NAME array_expr_01 LABELS cpython llvm c)
282282
RUN(NAME array_expr_02 LABELS cpython llvm c)
283283
RUN(NAME array_size_01 LABELS cpython llvm c)
284+
RUN(NAME array_size_02 LABELS cpython llvm c)
284285
RUN(NAME array_01 LABELS cpython llvm wasm c)
285286
RUN(NAME array_02 LABELS cpython wasm c)
286287
RUN(NAME bindc_01 LABELS cpython llvm c)

integration_tests/array_size_02.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from lpython import i32, f64, c32, c64, u32
2+
from numpy import empty, size
3+
4+
def main0():
5+
x: i32[4, 5, 2] = empty([4, 5, 2])
6+
y: f64[24, 100, 2, 5] = empty([24, 100, 2, 5])
7+
z: i32
8+
w: i32
9+
z = 2
10+
w = 3
11+
print(size(x))
12+
print(size(x, 0))
13+
print(size(x, 1))
14+
print(size(x, 2))
15+
print(size(y))
16+
print(size(y, 0))
17+
print(size(y, 1))
18+
print(size(y, z))
19+
print(size(y, w))
20+
21+
assert size(x) == 40
22+
assert size(x, 0) == 4
23+
assert size(x, 1) == 5
24+
assert size(x, 2) == 2
25+
assert size(y) == 24000
26+
assert size(y, 0) == 24
27+
assert size(y, 1) == 100
28+
assert size(y, z) == 2
29+
assert size(y, w) == 5
30+
31+
def main1():
32+
a: c32[12] = empty([12])
33+
b: c64[15, 15, 10] = empty([15, 15, 10])
34+
c: i32
35+
d: i32
36+
c = 1
37+
d = 2
38+
print(size(a))
39+
print(size(a, 0))
40+
print(size(b))
41+
print(size(b, 0))
42+
print(size(b, c))
43+
print(size(b, d))
44+
45+
assert size(a) == 12
46+
assert size(a, 0) == 12
47+
assert size(b) == 2250
48+
assert size(b, 0) == 15
49+
assert size(b, c) == 15
50+
assert size(b, d) == 10
51+
52+
def main2():
53+
a: i32[2, 3] = empty([2, 3])
54+
print(size(a))
55+
print(size(a, 0))
56+
print(size(a, 1))
57+
58+
assert size(a) == 2*3
59+
assert size(a, 0) == 2
60+
assert size(a, 1) == 3
61+
62+
def main3():
63+
a: u32[2, 3, 4] = empty([2, 3, 4])
64+
b: u64[10, 5] = empty([10, 5])
65+
c: i32
66+
d: i32
67+
c = 1
68+
d = 2
69+
print(size(a))
70+
print(size(a, 0))
71+
print(size(a, c))
72+
print(size(a, d))
73+
74+
print(size(b))
75+
print(size(b, 0))
76+
print(size(b, c))
77+
78+
assert size(a) == 2*3*4
79+
assert size(a, 0) == 2
80+
assert size(a, c) == 3
81+
assert size(a, d) == 4
82+
83+
assert size(b) == 50
84+
assert size(b, 0) == 10
85+
assert size(b, c) == 5
86+
87+
main0()
88+
main1()
89+
main2()
90+
main3()

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6749,13 +6749,16 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
67496749
std::to_string(args.size()) + " arguments instead.",
67506750
x.base.base.loc);
67516751
}
6752+
const Location &loc = x.base.base.loc;
67526753
ASR::expr_t *var = args[0].m_value;
67536754
ASR::expr_t *dim = nullptr;
6755+
ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
67546756
if (args.size() == 2) {
6755-
dim = args[1].m_value;
6757+
ASR::expr_t* const_one = ASRUtils::EXPR(make_IntegerConstant_t(al, loc, 1, int_type));
6758+
dim = ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, loc,
6759+
args[1].m_value, ASR::binopType::Add, const_one, int_type, nullptr));
67566760
}
6757-
ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 4, nullptr, 0));
6758-
tmp = ASR::make_ArraySize_t(al, x.base.base.loc, var, dim, int_type, nullptr);
6761+
tmp = ASR::make_ArraySize_t(al, loc, var, dim, int_type, nullptr);
67596762
return;
67606763
} else if (call_name == "empty") {
67616764
// TODO: check that the `empty` arguments are compatible

0 commit comments

Comments
 (0)