Skip to content

Commit 8e31c85

Browse files
Merge pull request #250 from ACRIOS-Systems/feature/improve_enum_usage
Improve enum usage in unions.
2 parents 3149c74 + 8fe5169 commit 8e31c85

File tree

3 files changed

+64
-14
lines changed

3 files changed

+64
-14
lines changed

erpcgen/src/PythonGenerator.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,56 @@ data_map PythonGenerator::getTypeInfo(DataType *t)
690690
data_map caseData;
691691
caseData["name"] = unionCase->getCaseName();
692692
caseData["value"] = unionCase->getCaseValue();
693+
if (info.has("discriminatorType") && (info["discriminatorType"].get().get()->getmap()["type"]->getvalue() == "enum"))
694+
{
695+
caseData["type"] = info["discriminatorType"];
696+
}
697+
else if (unionCase->getCaseName() != "")
698+
{
699+
for (auto it : m_globals->getSymbolsOfType(DataType::kEnumTypeSymbol))
700+
{
701+
EnumType *enumType = dynamic_cast<EnumType *>(it);
702+
assert(enumType);
703+
for (auto itMember :enumType->getMembers())
704+
{
705+
if (unionCase->getCaseName() == itMember->getName())
706+
{
707+
data_map typeInfo;
708+
typeInfo["name"] = enumType->getName();
709+
typeInfo["type"] = "enum";
710+
caseData["type"] = typeInfo;
711+
break;
712+
}
713+
}
714+
if(caseData.has("type"))
715+
{
716+
break;
717+
}
718+
}
719+
if(!caseData.has("type"))
720+
{
721+
for (auto it : m_globals->getSymbolsOfType(DataType::kConstSymbol))
722+
{
723+
ConstType *constType = dynamic_cast<ConstType *>(it);
724+
assert(constType);
725+
if (unionCase->getCaseName() == constType->getName())
726+
{
727+
data_map typeInfo;
728+
typeInfo["name"] = "";
729+
typeInfo["type"] = "const";
730+
caseData["type"] = typeInfo;
731+
break;
732+
}
733+
}
734+
}
735+
}
736+
else
737+
{
738+
data_map typeInfo;
739+
typeInfo["name"] = "";
740+
typeInfo["type"] = "";
741+
caseData["type"] = typeInfo;
742+
}
693743
data_list caseMembers;
694744
data_map caseMembersFree;
695745

erpcgen/src/templates/py_coders.template

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ codec.start_write_union({$self}{$info.discriminatorName})
111111
{% set defaultCase = c %}
112112
{% elif not c.isVoid %}
113113
{% set hasNonVoidCase = true %}
114-
{$indent}{$"if" if isFirst else "elif"} {$self}{$info.discriminatorName} == {$c.value}:
114+
{$indent}{$"if" if isFirst else "elif"} {$self}{$info.discriminatorName} == {% if c.name != "" && (c.type.type == "enum" || c.type.type == "const") %}{% if c.type.name != "" %}{$c.type.name}.{% endif %}{$c.name}{% else %}{$c.value}{% endif %}:
115115
{% for cm in c.members %}
116116
{% if cm.isNullable %}
117117
{$indent} if {$name}.{$cm.name} is None:
@@ -171,7 +171,7 @@ codec.start_write_union({$self}{$info.discriminatorName})
171171
{% set defaultCase = c %}
172172
{% elif not c.isVoid %}
173173
{% set hasNonVoidCase = true %}
174-
{$indent}{$"if" if isFirst else "elif"} {$self}{$info.discriminatorName} == {$c.value}:
174+
{$indent}{$"if" if isFirst else "elif"} {$self}{$info.discriminatorName} == {% if c.name != "" && (c.type.type == "enum" || c.type.type == "const") %}{% if c.type.name != "" %}{$c.type.name}.{% endif %}{$c.name}{% else %}{$c.value}{% endif %}:
175175
{% for cm in c.members %}
176176
{% if cm.isNullable %}
177177
{$indent} if not {$codec}.read_null_flag()

erpcgen/test/test_union_py.yml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,60 +52,60 @@ test/common.py:
5252
- blah = None # int32
5353
- def _read(self, codec)
5454
- discriminator = codec.start_read_union()
55-
- if discriminator == 0
55+
- if discriminator == fruitType.apple
5656
- self.myFoobar = foobar()._read(codec)
57-
- elif discriminator == 2
57+
- elif discriminator == fruitType.banana
5858
- self.x = codec.read_int32()
5959
- self.y = codec.read_float()
60-
- elif discriminator == 5
60+
- elif discriminator == fruitType.papaya
6161
- self.x = codec.read_int32()
6262
- self.y = codec.read_float()
63-
- elif discriminator == 1
63+
- elif discriminator == fruitType.orange
6464
- _n1 = codec.start_read_list()
6565
- self.a = []
6666
- for _i1 in range(_n1)
6767
- _v1 = codec.read_int32()
6868
- self.a.append(_v1)
69-
- elif discriminator == 3
69+
- elif discriminator == fruitType.coconut
7070
- self.cx = codec.read_uint32()
7171
- _n1 = codec.start_read_list()
7272
- self.cc = []
7373
- for _i1 in range(_n1)
7474
- _v1 = codec.read_int32()
7575
- self.cc.append(_v1)
76-
- elif discriminator == 10
76+
- elif discriminator == returnVal
7777
- self.ret = codec.read_int32()
7878
- else # default case
7979
- self.c = codec.read_bool()
8080
- self.blah = codec.read_int32()
8181
- return self, discriminator
8282
- def _write(self, codec, discriminator)
8383
- codec.start_write_union(discriminator)
84-
- if discriminator == 0
84+
- if discriminator == fruitType.apple
8585
- if self.myFoobar is None
8686
- raise ValueError("self.myFoobar is None")
8787
- self.myFoobar._write(codec)
88-
- elif discriminator == 2
88+
- elif discriminator == fruitType.banana
8989
- if self.x is None
9090
- raise ValueError("self.x is None")
9191
- codec.write_int32(self.x)
9292
- if self.y is None
9393
- raise ValueError("self.y is None")
9494
- codec.write_float(self.y)
95-
- elif discriminator == 5
95+
- elif discriminator == fruitType.papaya
9696
- if self.x is None
9797
- raise ValueError("self.x is None")
9898
- codec.write_int32(self.x)
9999
- if self.y is None
100100
- raise ValueError("self.y is None")
101101
- codec.write_float(self.y)
102-
- elif discriminator == 1
102+
- elif discriminator == fruitType.orange
103103
- if self.a is None
104104
- raise ValueError("self.a is None")
105105
- codec.start_write_list(len(self.a))
106106
- for _i1 in self.a
107107
- codec.write_int32(_i1)
108-
- elif discriminator == 3
108+
- elif discriminator == fruitType.coconut
109109
- if self.cx is None
110110
- raise ValueError("self.cx is None")
111111
- codec.write_uint32(self.cx)
@@ -114,7 +114,7 @@ test/common.py:
114114
- codec.start_write_list(len(self.cc))
115115
- for _i1 in self.cc
116116
- codec.write_int32(_i1)
117-
- elif discriminator == 10
117+
- elif discriminator == returnVal
118118
- if self.ret is None
119119
- raise ValueError("self.ret is None")
120120
- codec.write_int32(self.ret)

0 commit comments

Comments
 (0)