diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index c918c9ae76..9f7fd1aee9 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -506,6 +506,7 @@ RUN(NAME test_str_comparison LABELS cpython llvm c) RUN(NAME test_bit_length LABELS cpython llvm c) RUN(NAME str_to_list_cast LABELS cpython llvm c) RUN(NAME test_sys_01 LABELS cpython llvm c) +RUN(NAME intent_01 LABELS cpython llvm) RUN(NAME test_package_01 LABELS cpython llvm) diff --git a/integration_tests/intent_01.py b/integration_tests/intent_01.py new file mode 100644 index 0000000000..e4f679b05c --- /dev/null +++ b/integration_tests/intent_01.py @@ -0,0 +1,26 @@ +from lpython import i32, u32, f64, dataclass, In, Out, InOut + +@dataclass +class Foo: + p: i32 + +def f(x: i32, y: In[f64], z: InOut[list[u32]], w: Out[Foo]): + assert (x == -12) + assert abs(y - (4.44)) <= 1e-12 + z.append(u32(5)) + w.p = 24 + + +def main0(): + a: i32 = (-12) + b: f64 = 4.44 + c: list[u32] = [u32(1), u32(2), u32(3), u32(4)] + d: Foo = Foo(25) + + print(a, b, c, d.p) + + f(a, b, c, d) + assert c[-1] == u32(5) + assert d.p == 24 + +main0() diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index cee2dac475..8638412e3b 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -311,6 +311,7 @@ namespace LCompilers { std::map>& name2memidx) { switch( asr_type->type ) { case ASR::ttypeType::Integer: + case ASR::ttypeType::UnsignedInteger: case ASR::ttypeType::Real: case ASR::ttypeType::Logical: case ASR::ttypeType::Complex: { @@ -2635,7 +2636,7 @@ namespace LCompilers { * * int i = 0; * int j = end_point - 1; - * + * * tmp; * * while(j > i) { diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 8e9b4a6db3..2c90d9788e 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -1581,6 +1581,29 @@ class CommonVisitor : public AST::BaseVisitor { } } + AST::expr_t* get_var_intent_and_annotation(AST::expr_t *annotation, ASR::intentType &intent) { + if (AST::is_a(*annotation)) { + AST::Subscript_t *s = AST::down_cast(annotation); + if (AST::is_a(*s->m_value)) { + std::string ann_name = AST::down_cast(s->m_value)->m_id; + if (ann_name == "In") { + intent = ASRUtils::intent_in; + return s->m_slice; + } else if (ann_name == "InOut") { + intent = ASRUtils::intent_inout; + return s->m_slice; + } else if (ann_name == "Out") { + intent = ASRUtils::intent_out; + return s->m_slice; + } + return annotation; + } else { + throw SemanticError("Only Name in Subscript supported for now in annotation", annotation->base.loc); + } + } + return annotation; + } + // Convert Python AST type annotation to an ASR type // Examples: // i32, i64, f32, f64 @@ -3741,7 +3764,9 @@ class SymbolTableVisitor : public CommonVisitor { if (x.m_args.m_args[i].m_annotation == nullptr) { throw SemanticError("Argument does not have a type", loc); } - ASR::ttype_t *arg_type = ast_expr_to_asr_type(x.base.base.loc, *x.m_args.m_args[i].m_annotation); + ASR::intentType s_intent = ASRUtils::intent_unspecified; + AST::expr_t* arg_annotation_type = get_var_intent_and_annotation(x.m_args.m_args[i].m_annotation, s_intent); + ASR::ttype_t *arg_type = ast_expr_to_asr_type(x.base.base.loc, *arg_annotation_type); // Set the function as generic if an argument is typed with a type parameter if (ASRUtils::is_generic(*arg_type)) { ASR::ttype_t* arg_type_type = ASRUtils::get_type_parameter(arg_type); @@ -3766,12 +3791,13 @@ class SymbolTableVisitor : public CommonVisitor { } std::string arg_s = arg; - ASR::expr_t *value = nullptr; ASR::expr_t *init_expr = nullptr; - ASR::intentType s_intent = ASRUtils::intent_in; - if (ASRUtils::is_array(arg_type)) { - s_intent = ASRUtils::intent_inout; + if (s_intent == ASRUtils::intent_unspecified) { + s_intent = ASRUtils::intent_in; + if (ASRUtils::is_array(arg_type)) { + s_intent = ASRUtils::intent_inout; + } } ASR::storage_typeType storage_type = ASR::storage_typeType::Default; diff --git a/src/runtime/lpython/lpython.py b/src/runtime/lpython/lpython.py index 65684ce5e6..1583bf7641 100644 --- a/src/runtime/lpython/lpython.py +++ b/src/runtime/lpython/lpython.py @@ -84,6 +84,18 @@ def __init__(self, type, dims): Union = ctypes.Union Pointer = PointerType("Pointer") + +class Intent: + def __init__(self, type): + self._type = type + + def __getitem__(self, params): + return params + +In = Intent("In") +Out = Intent("Out") +InOut = Intent("InOut") + # Generics class TypeVar(): diff --git a/tests/intent_01.py b/tests/intent_01.py new file mode 100644 index 0000000000..b3552079c0 --- /dev/null +++ b/tests/intent_01.py @@ -0,0 +1,6 @@ +@dataclass +class Foo: + p: i32 + +def f(x: i32, y: In[f64], z: InOut[list[u32]], w: Out[Foo[5]]): + pass diff --git a/tests/reference/asr-intent_01-66824bc.json b/tests/reference/asr-intent_01-66824bc.json new file mode 100644 index 0000000000..8fdc0dad55 --- /dev/null +++ b/tests/reference/asr-intent_01-66824bc.json @@ -0,0 +1,13 @@ +{ + "basename": "asr-intent_01-66824bc", + "cmd": "lpython --show-asr --no-color {infile} -o {outfile}", + "infile": "tests/intent_01.py", + "infile_hash": "0a8fa2940567fccee2cfccd4af40f353b74dbe542590460a195246a5", + "outfile": null, + "outfile_hash": null, + "stdout": "asr-intent_01-66824bc.stdout", + "stdout_hash": "6c217775c0f43212356588d01124266dfe417ce0fd72c63c8cec30ad", + "stderr": null, + "stderr_hash": null, + "returncode": 0 +} \ No newline at end of file diff --git a/tests/reference/asr-intent_01-66824bc.stdout b/tests/reference/asr-intent_01-66824bc.stdout new file mode 100644 index 0000000000..59a0782219 --- /dev/null +++ b/tests/reference/asr-intent_01-66824bc.stdout @@ -0,0 +1,175 @@ +(TranslationUnit + (SymbolTable + 1 + { + _global_symbols: + (Module + (SymbolTable + 5 + { + Foo: + (StructType + (SymbolTable + 2 + { + p: + (Variable + 2 + p + [] + Local + () + () + Default + (Integer 4 []) + Source + Public + Required + .false. + ) + }) + Foo + [] + [p] + Source + Public + .false. + .false. + () + () + ), + f: + (Function + (SymbolTable + 3 + { + w: + (Variable + 3 + w + [] + Out + () + () + Default + (Struct + 5 Foo + [((IntegerConstant 0 (Integer 4 [])) + (IntegerConstant 5 (Integer 4 [])))] + ) + Source + Public + Required + .false. + ), + x: + (Variable + 3 + x + [] + In + () + () + Default + (Integer 4 []) + Source + Public + Required + .false. + ), + y: + (Variable + 3 + y + [] + In + () + () + Default + (Real 8 []) + Source + Public + Required + .false. + ), + z: + (Variable + 3 + z + [] + InOut + () + () + Default + (List + (UnsignedInteger + 4 + [] + ) + ) + Source + Public + Required + .false. + ) + }) + f + (FunctionType + [(Integer 4 []) + (Real 8 []) + (List + (UnsignedInteger + 4 + [] + ) + ) + (Struct + 5 Foo + [((IntegerConstant 0 (Integer 4 [])) + (IntegerConstant 5 (Integer 4 [])))] + )] + () + Source + Implementation + () + .false. + .false. + .false. + .false. + .false. + [] + [] + .false. + ) + [] + [(Var 3 x) + (Var 3 y) + (Var 3 z) + (Var 3 w)] + [] + () + Public + .false. + .false. + () + ) + }) + _global_symbols + [] + .false. + .false. + ), + main_program: + (Program + (SymbolTable + 4 + { + + }) + main_program + [] + [] + ) + }) + [] +) diff --git a/tests/tests.toml b/tests/tests.toml index cfb980ae44..9e3e82f142 100644 --- a/tests/tests.toml +++ b/tests/tests.toml @@ -291,8 +291,6 @@ c = true filename = "test_end_sep_keywords.py" asr = true -# integration_tests - [[test]] filename = "lpython1.py" llvm = true @@ -301,6 +299,12 @@ llvm = true filename = "print_str.py" wat = true +[[test]] +filename = "intent_01.py" +asr = true + +# integration_tests + [[test]] filename = "../integration_tests/test_builtin.py" asr = true