Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ RUN(NAME bindc_01 LABELS cpython llvm c)
RUN(NAME bindc_02 LABELS cpython llvm c)
RUN(NAME bindc_04 LABELS llvm c)
RUN(NAME bindc_07 LABELS cpython llvm c)
RUN(NAME bindc_09 LABELS cpython llvm c)
RUN(NAME exit_01 LABELS cpython llvm c)
RUN(NAME exit_02 FAIL LABELS cpython llvm c)
RUN(NAME exit_03 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
Expand Down
43 changes: 43 additions & 0 deletions integration_tests/bindc_09.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from enum import Enum

from lpython import CPtr, c_p_pointer, p_c_pointer, dataclass, empty_c_void_p, pointer, Pointer, i32, ccallable

class Value(Enum):
TEN: i32 = 10
TWO: i32 = 2
ONE: i32 = 1
FIVE: i32 = 5

@dataclass
class Foo:
value: Value

@ccallable
@dataclass
class FooC:
value: Value

def bar(foo_ptr: CPtr) -> None:
foo: Pointer[Foo] = c_p_pointer(foo_ptr, Foo)
foo.value = Value.FIVE

def barc(foo_ptr: CPtr) -> None:
foo: Pointer[FooC] = c_p_pointer(foo_ptr, FooC)
foo.value = Value.ONE

def main() -> None:
foo: Foo = Foo(Value.TEN)
fooc: FooC = FooC(Value.TWO)
foo_ptr: CPtr = empty_c_void_p()

p_c_pointer(pointer(foo), foo_ptr)
bar(foo_ptr)
print(foo.value, foo.value.name)
assert foo.value == Value.FIVE

p_c_pointer(pointer(fooc), foo_ptr)
barc(foo_ptr)
print(fooc.value)
assert fooc.value == Value.ONE.value

main()
3 changes: 2 additions & 1 deletion integration_tests/structs_15.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from lpython import i32, i16, i8, i64, CPtr, dataclass, ccall, Pointer, c_p_pointer, sizeof
from lpython import i32, i16, i8, CPtr, dataclass, ccall, Pointer, c_p_pointer, sizeof, ccallable

@ccallable
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel that any dataclass which is being used in C code as well should be decorated with ccallable and lpython.py should raise an error in case we try to use a pure Python dataclass with C-APIs. What do you say @certik?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think so.

@dataclass
class A:
x: i16
Expand Down
20 changes: 19 additions & 1 deletion src/runtime/lpython/lpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def dataclass(arg):
arg.__class_getitem__ = lambda self: None
return py_dataclass(arg)

def is_ctypes_Structure(obj):
return (isclass(obj) and issubclass(obj, ctypes.Structure))

def is_dataclass(obj):
return ((isclass(obj) and issubclass(obj, ctypes.Structure)) or
py_is_dataclass(obj))
Expand Down Expand Up @@ -221,6 +224,7 @@ class c_double_complex(c_complex):
_fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]

def convert_type_to_ctype(arg):
from enum import Enum
if arg == f64:
return ctypes.c_double
elif arg == f32:
Expand Down Expand Up @@ -258,6 +262,9 @@ def convert_type_to_ctype(arg):
return ctypes.POINTER(type)
elif is_dataclass(arg):
return convert_to_ctypes_Structure(arg)
elif issubclass(arg, Enum):
# TODO: store enum in ctypes.Structure with name and value as fields.
return ctypes.c_int64
else:
raise NotImplementedError("Type %r not implemented" % arg)

Expand Down Expand Up @@ -405,6 +412,7 @@ def __init__(self, *args):
super().__init__(*args)

for field, arg in zip(self._fields_, args):
from enum import Enum
member = self.__getattribute__(field[0])
value = arg
if isinstance(member, ctypes.Array):
Expand All @@ -417,6 +425,8 @@ def __init__(self, *args):
value = value.flatten().tolist()
value = [c_double_complex(val.real, val.imag) for val in value]
value = type(member)(*value)
elif isinstance(value, Enum):
value = value.value
self.__setattr__(field[0], value)

ctypes_Structure.__name__ = f.__name__
Expand Down Expand Up @@ -498,6 +508,7 @@ def __getattr__(self, name: str):

def __setattr__(self, name: str, value):
name_ = self.ctypes_ptr.contents.__getattribute__(name)
from enum import Enum
if isinstance(name_, c_float_complex):
if isinstance(value, complex):
value = c_float_complex(value.real, value.imag)
Expand All @@ -518,6 +529,8 @@ def __setattr__(self, name: str, value):
value = value.flatten().tolist()
value = [c_double_complex(val.real, val.imag) for val in value]
value = type(name_)(*value)
elif isinstance(value, Enum):
value = value.value
self.ctypes_ptr.contents.__setattr__(name, value)

def c_p_pointer(cptr, targettype):
Expand All @@ -526,9 +539,14 @@ def c_p_pointer(cptr, targettype):
newa = ctypes.cast(cptr, targettype_ptr)
return newa
else:
if py_is_dataclass(targettype):
if cptr.value is None:
return None
return ctypes.cast(cptr, ctypes.py_object).value

targettype_ptr = ctypes.POINTER(targettype_ptr)
newa = ctypes.cast(cptr, targettype_ptr)
if is_dataclass(targettype):
if is_ctypes_Structure(targettype):
# return after wrapping newa inside PointerToStruct
return PointerToStruct(newa)
return newa
Expand Down