Skip to content

Commit 31ce17a

Browse files
authored
Merge pull request #1820 from Shaikh-Ubaid/param_access_in_inout_out
ASR: Support param access in, out, inout
2 parents 8d679fa + 26017ad commit 31ce17a

File tree

9 files changed

+272
-8
lines changed

9 files changed

+272
-8
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ RUN(NAME test_str_comparison LABELS cpython llvm c)
506506
RUN(NAME test_bit_length LABELS cpython llvm c)
507507
RUN(NAME str_to_list_cast LABELS cpython llvm c)
508508
RUN(NAME test_sys_01 LABELS cpython llvm c)
509+
RUN(NAME intent_01 LABELS cpython llvm)
509510

510511

511512
RUN(NAME test_package_01 LABELS cpython llvm)

integration_tests/intent_01.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from lpython import i32, u32, f64, dataclass, In, Out, InOut
2+
3+
@dataclass
4+
class Foo:
5+
p: i32
6+
7+
def f(x: i32, y: In[f64], z: InOut[list[u32]], w: Out[Foo]):
8+
assert (x == -12)
9+
assert abs(y - (4.44)) <= 1e-12
10+
z.append(u32(5))
11+
w.p = 24
12+
13+
14+
def main0():
15+
a: i32 = (-12)
16+
b: f64 = 4.44
17+
c: list[u32] = [u32(1), u32(2), u32(3), u32(4)]
18+
d: Foo = Foo(25)
19+
20+
print(a, b, c, d.p)
21+
22+
f(a, b, c, d)
23+
assert c[-1] == u32(5)
24+
assert d.p == 24
25+
26+
main0()

src/libasr/codegen/llvm_utils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ namespace LCompilers {
311311
std::map<std::string, std::map<std::string, int>>& name2memidx) {
312312
switch( asr_type->type ) {
313313
case ASR::ttypeType::Integer:
314+
case ASR::ttypeType::UnsignedInteger:
314315
case ASR::ttypeType::Real:
315316
case ASR::ttypeType::Logical:
316317
case ASR::ttypeType::Complex: {
@@ -2635,7 +2636,7 @@ namespace LCompilers {
26352636
*
26362637
* int i = 0;
26372638
* int j = end_point - 1;
2638-
*
2639+
*
26392640
* tmp;
26402641
*
26412642
* while(j > i) {

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,29 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
15811581
}
15821582
}
15831583

1584+
AST::expr_t* get_var_intent_and_annotation(AST::expr_t *annotation, ASR::intentType &intent) {
1585+
if (AST::is_a<AST::Subscript_t>(*annotation)) {
1586+
AST::Subscript_t *s = AST::down_cast<AST::Subscript_t>(annotation);
1587+
if (AST::is_a<AST::Name_t>(*s->m_value)) {
1588+
std::string ann_name = AST::down_cast<AST::Name_t>(s->m_value)->m_id;
1589+
if (ann_name == "In") {
1590+
intent = ASRUtils::intent_in;
1591+
return s->m_slice;
1592+
} else if (ann_name == "InOut") {
1593+
intent = ASRUtils::intent_inout;
1594+
return s->m_slice;
1595+
} else if (ann_name == "Out") {
1596+
intent = ASRUtils::intent_out;
1597+
return s->m_slice;
1598+
}
1599+
return annotation;
1600+
} else {
1601+
throw SemanticError("Only Name in Subscript supported for now in annotation", annotation->base.loc);
1602+
}
1603+
}
1604+
return annotation;
1605+
}
1606+
15841607
// Convert Python AST type annotation to an ASR type
15851608
// Examples:
15861609
// i32, i64, f32, f64
@@ -3741,7 +3764,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
37413764
if (x.m_args.m_args[i].m_annotation == nullptr) {
37423765
throw SemanticError("Argument does not have a type", loc);
37433766
}
3744-
ASR::ttype_t *arg_type = ast_expr_to_asr_type(x.base.base.loc, *x.m_args.m_args[i].m_annotation);
3767+
ASR::intentType s_intent = ASRUtils::intent_unspecified;
3768+
AST::expr_t* arg_annotation_type = get_var_intent_and_annotation(x.m_args.m_args[i].m_annotation, s_intent);
3769+
ASR::ttype_t *arg_type = ast_expr_to_asr_type(x.base.base.loc, *arg_annotation_type);
37453770
// Set the function as generic if an argument is typed with a type parameter
37463771
if (ASRUtils::is_generic(*arg_type)) {
37473772
ASR::ttype_t* arg_type_type = ASRUtils::get_type_parameter(arg_type);
@@ -3766,12 +3791,13 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
37663791
}
37673792

37683793
std::string arg_s = arg;
3769-
37703794
ASR::expr_t *value = nullptr;
37713795
ASR::expr_t *init_expr = nullptr;
3772-
ASR::intentType s_intent = ASRUtils::intent_in;
3773-
if (ASRUtils::is_array(arg_type)) {
3774-
s_intent = ASRUtils::intent_inout;
3796+
if (s_intent == ASRUtils::intent_unspecified) {
3797+
s_intent = ASRUtils::intent_in;
3798+
if (ASRUtils::is_array(arg_type)) {
3799+
s_intent = ASRUtils::intent_inout;
3800+
}
37753801
}
37763802
ASR::storage_typeType storage_type =
37773803
ASR::storage_typeType::Default;

src/runtime/lpython/lpython.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,18 @@ def __init__(self, type, dims):
8484
Union = ctypes.Union
8585
Pointer = PointerType("Pointer")
8686

87+
88+
class Intent:
89+
def __init__(self, type):
90+
self._type = type
91+
92+
def __getitem__(self, params):
93+
return params
94+
95+
In = Intent("In")
96+
Out = Intent("Out")
97+
InOut = Intent("InOut")
98+
8799
# Generics
88100

89101
class TypeVar():

tests/intent_01.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
@dataclass
2+
class Foo:
3+
p: i32
4+
5+
def f(x: i32, y: In[f64], z: InOut[list[u32]], w: Out[Foo[5]]):
6+
pass
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-intent_01-66824bc",
3+
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
4+
"infile": "tests/intent_01.py",
5+
"infile_hash": "0a8fa2940567fccee2cfccd4af40f353b74dbe542590460a195246a5",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": "asr-intent_01-66824bc.stdout",
9+
"stdout_hash": "6c217775c0f43212356588d01124266dfe417ce0fd72c63c8cec30ad",
10+
"stderr": null,
11+
"stderr_hash": null,
12+
"returncode": 0
13+
}
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
(TranslationUnit
2+
(SymbolTable
3+
1
4+
{
5+
_global_symbols:
6+
(Module
7+
(SymbolTable
8+
5
9+
{
10+
Foo:
11+
(StructType
12+
(SymbolTable
13+
2
14+
{
15+
p:
16+
(Variable
17+
2
18+
p
19+
[]
20+
Local
21+
()
22+
()
23+
Default
24+
(Integer 4 [])
25+
Source
26+
Public
27+
Required
28+
.false.
29+
)
30+
})
31+
Foo
32+
[]
33+
[p]
34+
Source
35+
Public
36+
.false.
37+
.false.
38+
()
39+
()
40+
),
41+
f:
42+
(Function
43+
(SymbolTable
44+
3
45+
{
46+
w:
47+
(Variable
48+
3
49+
w
50+
[]
51+
Out
52+
()
53+
()
54+
Default
55+
(Struct
56+
5 Foo
57+
[((IntegerConstant 0 (Integer 4 []))
58+
(IntegerConstant 5 (Integer 4 [])))]
59+
)
60+
Source
61+
Public
62+
Required
63+
.false.
64+
),
65+
x:
66+
(Variable
67+
3
68+
x
69+
[]
70+
In
71+
()
72+
()
73+
Default
74+
(Integer 4 [])
75+
Source
76+
Public
77+
Required
78+
.false.
79+
),
80+
y:
81+
(Variable
82+
3
83+
y
84+
[]
85+
In
86+
()
87+
()
88+
Default
89+
(Real 8 [])
90+
Source
91+
Public
92+
Required
93+
.false.
94+
),
95+
z:
96+
(Variable
97+
3
98+
z
99+
[]
100+
InOut
101+
()
102+
()
103+
Default
104+
(List
105+
(UnsignedInteger
106+
4
107+
[]
108+
)
109+
)
110+
Source
111+
Public
112+
Required
113+
.false.
114+
)
115+
})
116+
f
117+
(FunctionType
118+
[(Integer 4 [])
119+
(Real 8 [])
120+
(List
121+
(UnsignedInteger
122+
4
123+
[]
124+
)
125+
)
126+
(Struct
127+
5 Foo
128+
[((IntegerConstant 0 (Integer 4 []))
129+
(IntegerConstant 5 (Integer 4 [])))]
130+
)]
131+
()
132+
Source
133+
Implementation
134+
()
135+
.false.
136+
.false.
137+
.false.
138+
.false.
139+
.false.
140+
[]
141+
[]
142+
.false.
143+
)
144+
[]
145+
[(Var 3 x)
146+
(Var 3 y)
147+
(Var 3 z)
148+
(Var 3 w)]
149+
[]
150+
()
151+
Public
152+
.false.
153+
.false.
154+
()
155+
)
156+
})
157+
_global_symbols
158+
[]
159+
.false.
160+
.false.
161+
),
162+
main_program:
163+
(Program
164+
(SymbolTable
165+
4
166+
{
167+
168+
})
169+
main_program
170+
[]
171+
[]
172+
)
173+
})
174+
[]
175+
)

tests/tests.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,6 @@ c = true
291291
filename = "test_end_sep_keywords.py"
292292
asr = true
293293

294-
# integration_tests
295-
296294
[[test]]
297295
filename = "lpython1.py"
298296
llvm = true
@@ -301,6 +299,12 @@ llvm = true
301299
filename = "print_str.py"
302300
wat = true
303301

302+
[[test]]
303+
filename = "intent_01.py"
304+
asr = true
305+
306+
# integration_tests
307+
304308
[[test]]
305309
filename = "../integration_tests/test_builtin.py"
306310
asr = true

0 commit comments

Comments
 (0)