Skip to content

Commit 49a65ea

Browse files
authored
Merge pull request #613 from czgdp1807/structs_02
Initialisation support for ``@dataclass``
2 parents 43dd82b + 13f0170 commit 49a65ea

File tree

4 files changed

+15
-10
lines changed

4 files changed

+15
-10
lines changed

integration_tests/structs_01.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,19 @@
22

33
@dataclass
44
class A:
5-
x: i32
65
y: f32
6+
x: i32
77

88
def f(a: A):
99
print(a.x)
1010
print(a.y)
1111

1212
def g():
1313
x: A
14-
x = A(3, 3.3)
14+
x = A(3.25, 3)
1515
f(x)
16-
# TODO: the above constructor does not initialize `A` in LPython yet, so
17-
# the below does not work:
18-
#assert x.x == 3
19-
#assert x.y == 3.3
16+
assert x.x == 3
17+
assert x.y == 3.25
2018

2119
x.x = 5
2220
x.y = 5.5

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
576576
}
577577

578578
void visit_expr_list(Vec<ASR::call_arg_t>& exprs, size_t n,
579-
Vec<ASR::expr_t*> exprs_vec) {
579+
Vec<ASR::expr_t*>& exprs_vec) {
580580
LFORTRAN_ASSERT(exprs_vec.reserve_called);
581581
for( size_t i = 0; i < n; i++ ) {
582582
exprs_vec.push_back(al, exprs[i].m_value);
@@ -735,6 +735,13 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
735735
Vec<ASR::expr_t*> args_new;
736736
args_new.reserve(al, args.size());
737737
visit_expr_list(args, args.size(), args_new);
738+
ASR::DerivedType_t* derivedtype = ASR::down_cast<ASR::DerivedType_t>(s);
739+
for( size_t i = 0; i < std::min(args.size(), derivedtype->n_members); i++ ) {
740+
std::string member_name = derivedtype->m_members[i];
741+
ASR::Variable_t* member_var = ASR::down_cast<ASR::Variable_t>(
742+
derivedtype->m_symtab->resolve_symbol(member_name));
743+
args_new.p[i] = cast_helper(member_var->m_type, args_new[i], true);
744+
}
738745
ASR::ttype_t* der_type = ASRUtils::TYPE(ASR::make_Derived_t(al, loc, s, nullptr, 0));
739746
return ASR::make_DerivedTypeConstructor_t(al, loc, s, args_new.p, args_new.size(), der_type, nullptr);
740747
} else {

tests/reference/asr-structs_01-be14d49.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
"basename": "asr-structs_01-be14d49",
33
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
44
"infile": "tests/../integration_tests/structs_01.py",
5-
"infile_hash": "db2af95fea66964714bbf3847ba05b2e58a48c8653243acfe39c8d26",
5+
"infile_hash": "b3fd3a07d6aab6b40e16f209a74c4d64448b8a0e509e539f96ebd167",
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-structs_01-be14d49.stdout",
9-
"stdout_hash": "ba8e56eab1d50c57aebc77f4d7a7977c35974520b50b056324764f86",
9+
"stdout_hash": "13f16c8ce74e0249f7800d65538d55bbdb84d06ff6ee7a7887466ed3",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
(TranslationUnit (SymbolTable 1 {A: (DerivedType (SymbolTable 2 {x: (Variable 2 x Local () () Default (Integer 4 []) Source Public Required .false.), y: (Variable 2 y Local () () Default (Real 4 []) Source Public Required .false.)}) A [x y] Source Public ()), _lpython_main_program: (Subroutine (SymbolTable 6 {}) _lpython_main_program [] [(SubroutineCall 1 g () [] ())] Source Public Implementation () .false. .false.), f: (Subroutine (SymbolTable 3 {a: (Variable 3 a In () () Default (Derived 1 A []) Source Public Required .false.)}) f [(Var 3 a)] [(Print () [(DerivedRef (Var 3 a) 2 x (Integer 4 []) ())]) (Print () [(DerivedRef (Var 3 a) 2 y (Real 4 []) ())])] Source Public Implementation () .false. .false.), g: (Subroutine (SymbolTable 4 {x: (Variable 4 x Local () () Default (Derived 1 A []) Source Public Required .false.)}) g [] [(= (Var 4 x) (DerivedTypeConstructor 1 A [] (Derived 1 A []) ()) ()) (SubroutineCall 1 f () [((Var 4 x))] ()) (= (DerivedRef (Var 4 x) 2 x (Integer 4 []) ()) (IntegerConstant 5 (Integer 4 [])) ()) (= (DerivedRef (Var 4 x) 2 y (Real 4 []) ()) (Cast (RealConstant 5.50000000000000000e+00 (Real 8 [])) RealToReal (Real 4 []) (RealConstant 5.50000000000000000e+00 (Real 4 []))) ()) (SubroutineCall 1 f () [((Var 4 x))] ()) (Assert (Compare (DerivedRef (Var 4 x) 2 x (Integer 4 []) ()) Eq (IntegerConstant 5 (Integer 4 [])) (Logical 4 []) () ()) ()) (Assert (Compare (Cast (DerivedRef (Var 4 x) 2 y (Real 4 []) ()) RealToReal (Real 8 []) ()) Eq (RealConstant 5.50000000000000000e+00 (Real 8 [])) (Logical 4 []) () ()) ())] Source Public Implementation () .false. .false.), main_program: (Program (SymbolTable 5 {}) main_program [] [(SubroutineCall 1 _lpython_main_program () [] ())])}) [])
1+
(TranslationUnit (SymbolTable 1 {A: (DerivedType (SymbolTable 2 {x: (Variable 2 x Local () () Default (Integer 4 []) Source Public Required .false.), y: (Variable 2 y Local () () Default (Real 4 []) Source Public Required .false.)}) A [y x] Source Public ()), _lpython_main_program: (Subroutine (SymbolTable 6 {}) _lpython_main_program [] [(SubroutineCall 1 g () [] ())] Source Public Implementation () .false. .false.), f: (Subroutine (SymbolTable 3 {a: (Variable 3 a In () () Default (Derived 1 A []) Source Public Required .false.)}) f [(Var 3 a)] [(Print () [(DerivedRef (Var 3 a) 2 x (Integer 4 []) ())]) (Print () [(DerivedRef (Var 3 a) 2 y (Real 4 []) ())])] Source Public Implementation () .false. .false.), g: (Subroutine (SymbolTable 4 {x: (Variable 4 x Local () () Default (Derived 1 A []) Source Public Required .false.)}) g [] [(= (Var 4 x) (DerivedTypeConstructor 1 A [(Cast (RealConstant 3.25000000000000000e+00 (Real 8 [])) RealToReal (Real 4 []) (RealConstant 3.25000000000000000e+00 (Real 4 []))) (IntegerConstant 3 (Integer 4 []))] (Derived 1 A []) ()) ()) (SubroutineCall 1 f () [((Var 4 x))] ()) (Assert (Compare (DerivedRef (Var 4 x) 2 x (Integer 4 []) ()) Eq (IntegerConstant 3 (Integer 4 [])) (Logical 4 []) () ()) ()) (Assert (Compare (Cast (DerivedRef (Var 4 x) 2 y (Real 4 []) ()) RealToReal (Real 8 []) ()) Eq (RealConstant 3.25000000000000000e+00 (Real 8 [])) (Logical 4 []) () ()) ()) (= (DerivedRef (Var 4 x) 2 x (Integer 4 []) ()) (IntegerConstant 5 (Integer 4 [])) ()) (= (DerivedRef (Var 4 x) 2 y (Real 4 []) ()) (Cast (RealConstant 5.50000000000000000e+00 (Real 8 [])) RealToReal (Real 4 []) (RealConstant 5.50000000000000000e+00 (Real 4 []))) ()) (SubroutineCall 1 f () [((Var 4 x))] ()) (Assert (Compare (DerivedRef (Var 4 x) 2 x (Integer 4 []) ()) Eq (IntegerConstant 5 (Integer 4 [])) (Logical 4 []) () ()) ()) (Assert (Compare (Cast (DerivedRef (Var 4 x) 2 y (Real 4 []) ()) RealToReal (Real 8 []) ()) Eq (RealConstant 5.50000000000000000e+00 (Real 8 [])) (Logical 4 []) () ()) ())] Source Public Implementation () .false. .false.), main_program: (Program (SymbolTable 5 {}) main_program [] [(SubroutineCall 1 _lpython_main_program () [] ())])}) [])

0 commit comments

Comments
 (0)