Skip to content

Commit 04bf030

Browse files
authored
Merge pull request #1823 from czgdp1807/enum_struct
Support structs with enum fields in ``c_p_pointer``
2 parents 1aad65e + acb5499 commit 04bf030

File tree

4 files changed

+65
-2
lines changed

4 files changed

+65
-2
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ RUN(NAME bindc_02 LABELS cpython llvm c)
289289
RUN(NAME bindc_04 LABELS llvm c)
290290
RUN(NAME bindc_07 LABELS cpython llvm c)
291291
RUN(NAME bindc_08 LABELS cpython llvm c)
292+
RUN(NAME bindc_09 LABELS cpython llvm c)
292293
RUN(NAME exit_01 LABELS cpython llvm c)
293294
RUN(NAME exit_02 FAIL LABELS cpython llvm c)
294295
RUN(NAME exit_03 LABELS cpython llvm c wasm wasm_x86 wasm_x64)

integration_tests/bindc_09.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from enum import Enum
2+
3+
from lpython import CPtr, c_p_pointer, p_c_pointer, dataclass, empty_c_void_p, pointer, Pointer, i32, ccallable
4+
5+
class Value(Enum):
6+
TEN: i32 = 10
7+
TWO: i32 = 2
8+
ONE: i32 = 1
9+
FIVE: i32 = 5
10+
11+
@dataclass
12+
class Foo:
13+
value: Value
14+
15+
@ccallable
16+
@dataclass
17+
class FooC:
18+
value: Value
19+
20+
def bar(foo_ptr: CPtr) -> None:
21+
foo: Pointer[Foo] = c_p_pointer(foo_ptr, Foo)
22+
foo.value = Value.FIVE
23+
24+
def barc(foo_ptr: CPtr) -> None:
25+
foo: Pointer[FooC] = c_p_pointer(foo_ptr, FooC)
26+
foo.value = Value.ONE
27+
28+
def main() -> None:
29+
foo: Foo = Foo(Value.TEN)
30+
fooc: FooC = FooC(Value.TWO)
31+
foo_ptr: CPtr = empty_c_void_p()
32+
33+
p_c_pointer(pointer(foo), foo_ptr)
34+
bar(foo_ptr)
35+
print(foo.value, foo.value.name)
36+
assert foo.value == Value.FIVE
37+
38+
p_c_pointer(pointer(fooc), foo_ptr)
39+
barc(foo_ptr)
40+
print(fooc.value)
41+
assert fooc.value == Value.ONE.value
42+
43+
main()

integration_tests/structs_15.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from lpython import i32, i16, i8, i64, CPtr, dataclass, ccall, Pointer, c_p_pointer, sizeof
1+
from lpython import i32, i16, i8, CPtr, dataclass, ccall, Pointer, c_p_pointer, sizeof, ccallable
22

3+
@ccallable
34
@dataclass
45
class A:
56
x: i16

src/runtime/lpython/lpython.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def __class_getitem__(key):
5050

5151
return py_dataclass(arg)
5252

53+
def is_ctypes_Structure(obj):
54+
return (isclass(obj) and issubclass(obj, ctypes.Structure))
55+
5356
def is_dataclass(obj):
5457
return ((isclass(obj) and issubclass(obj, ctypes.Structure)) or
5558
py_is_dataclass(obj))
@@ -236,6 +239,7 @@ class c_double_complex(c_complex):
236239
_fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]
237240

238241
def convert_type_to_ctype(arg):
242+
from enum import Enum
239243
if arg == f64:
240244
return ctypes.c_double
241245
elif arg == f32:
@@ -275,6 +279,9 @@ def convert_type_to_ctype(arg):
275279
return ctypes.POINTER(type)
276280
elif is_dataclass(arg):
277281
return convert_to_ctypes_Structure(arg)
282+
elif issubclass(arg, Enum):
283+
# TODO: store enum in ctypes.Structure with name and value as fields.
284+
return ctypes.c_int64
278285
else:
279286
raise NotImplementedError("Type %r not implemented" % arg)
280287

@@ -422,6 +429,7 @@ def __init__(self, *args):
422429
super().__init__(*args)
423430

424431
for field, arg in zip(self._fields_, args):
432+
from enum import Enum
425433
member = self.__getattribute__(field[0])
426434
value = arg
427435
if isinstance(member, ctypes.Array):
@@ -434,6 +442,8 @@ def __init__(self, *args):
434442
value = value.flatten().tolist()
435443
value = [c_double_complex(val.real, val.imag) for val in value]
436444
value = type(member)(*value)
445+
elif isinstance(value, Enum):
446+
value = value.value
437447
self.__setattr__(field[0], value)
438448

439449
ctypes_Structure.__name__ = f.__name__
@@ -515,6 +525,7 @@ def __getattr__(self, name: str):
515525

516526
def __setattr__(self, name: str, value):
517527
name_ = self.ctypes_ptr.contents.__getattribute__(name)
528+
from enum import Enum
518529
if isinstance(name_, c_float_complex):
519530
if isinstance(value, complex):
520531
value = c_float_complex(value.real, value.imag)
@@ -535,6 +546,8 @@ def __setattr__(self, name: str, value):
535546
value = value.flatten().tolist()
536547
value = [c_double_complex(val.real, val.imag) for val in value]
537548
value = type(name_)(*value)
549+
elif isinstance(value, Enum):
550+
value = value.value
538551
self.ctypes_ptr.contents.__setattr__(name, value)
539552

540553
def c_p_pointer(cptr, targettype):
@@ -545,9 +558,14 @@ def c_p_pointer(cptr, targettype):
545558
newa = ctypes.cast(cptr, targettype_ptr)
546559
return newa
547560
else:
561+
if py_is_dataclass(targettype):
562+
if cptr.value is None:
563+
return None
564+
return ctypes.cast(cptr, ctypes.py_object).value
565+
548566
targettype_ptr = ctypes.POINTER(targettype_ptr)
549567
newa = ctypes.cast(cptr, targettype_ptr)
550-
if is_dataclass(targettype):
568+
if is_ctypes_Structure(targettype):
551569
# return after wrapping newa inside PointerToStruct
552570
return PointerToStruct(newa)
553571
return newa

0 commit comments

Comments
 (0)