@@ -50,6 +50,9 @@ def __class_getitem__(key):
50
50
51
51
return py_dataclass (arg )
52
52
53
+ def is_ctypes_Structure (obj ):
54
+ return (isclass (obj ) and issubclass (obj , ctypes .Structure ))
55
+
53
56
def is_dataclass (obj ):
54
57
return ((isclass (obj ) and issubclass (obj , ctypes .Structure )) or
55
58
py_is_dataclass (obj ))
@@ -236,6 +239,7 @@ class c_double_complex(c_complex):
236
239
_fields_ = [("real" , ctypes .c_double ), ("imag" , ctypes .c_double )]
237
240
238
241
def convert_type_to_ctype (arg ):
242
+ from enum import Enum
239
243
if arg == f64 :
240
244
return ctypes .c_double
241
245
elif arg == f32 :
@@ -275,6 +279,9 @@ def convert_type_to_ctype(arg):
275
279
return ctypes .POINTER (type )
276
280
elif is_dataclass (arg ):
277
281
return convert_to_ctypes_Structure (arg )
282
+ elif issubclass (arg , Enum ):
283
+ # TODO: store enum in ctypes.Structure with name and value as fields.
284
+ return ctypes .c_int64
278
285
else :
279
286
raise NotImplementedError ("Type %r not implemented" % arg )
280
287
@@ -422,6 +429,7 @@ def __init__(self, *args):
422
429
super ().__init__ (* args )
423
430
424
431
for field , arg in zip (self ._fields_ , args ):
432
+ from enum import Enum
425
433
member = self .__getattribute__ (field [0 ])
426
434
value = arg
427
435
if isinstance (member , ctypes .Array ):
@@ -434,6 +442,8 @@ def __init__(self, *args):
434
442
value = value .flatten ().tolist ()
435
443
value = [c_double_complex (val .real , val .imag ) for val in value ]
436
444
value = type (member )(* value )
445
+ elif isinstance (value , Enum ):
446
+ value = value .value
437
447
self .__setattr__ (field [0 ], value )
438
448
439
449
ctypes_Structure .__name__ = f .__name__
@@ -515,6 +525,7 @@ def __getattr__(self, name: str):
515
525
516
526
def __setattr__ (self , name : str , value ):
517
527
name_ = self .ctypes_ptr .contents .__getattribute__ (name )
528
+ from enum import Enum
518
529
if isinstance (name_ , c_float_complex ):
519
530
if isinstance (value , complex ):
520
531
value = c_float_complex (value .real , value .imag )
@@ -535,6 +546,8 @@ def __setattr__(self, name: str, value):
535
546
value = value .flatten ().tolist ()
536
547
value = [c_double_complex (val .real , val .imag ) for val in value ]
537
548
value = type (name_ )(* value )
549
+ elif isinstance (value , Enum ):
550
+ value = value .value
538
551
self .ctypes_ptr .contents .__setattr__ (name , value )
539
552
540
553
def c_p_pointer (cptr , targettype ):
@@ -545,9 +558,14 @@ def c_p_pointer(cptr, targettype):
545
558
newa = ctypes .cast (cptr , targettype_ptr )
546
559
return newa
547
560
else :
561
+ if py_is_dataclass (targettype ):
562
+ if cptr .value is None :
563
+ return None
564
+ return ctypes .cast (cptr , ctypes .py_object ).value
565
+
548
566
targettype_ptr = ctypes .POINTER (targettype_ptr )
549
567
newa = ctypes .cast (cptr , targettype_ptr )
550
- if is_dataclass (targettype ):
568
+ if is_ctypes_Structure (targettype ):
551
569
# return after wrapping newa inside PointerToStruct
552
570
return PointerToStruct (newa )
553
571
return newa
0 commit comments