Skip to content

Commit 8ea663f

Browse files
committed
Fix recursive deserialization of cbor bytes
1 parent 68f26ee commit 8ea663f

File tree

2 files changed

+46
-40
lines changed

2 files changed

+46
-40
lines changed

pycardano/serialization.py

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import typing
6+
57
import re
68
from collections import OrderedDict, UserList, defaultdict
79
from copy import deepcopy
@@ -413,65 +415,67 @@ def _restore_dataclass_field(
413415
Returns:
414416
Union[:const:`Primitive`, CBORSerializable]: A CBOR primitive or a CBORSerializable.
415417
"""
418+
416419
if "object_hook" in f.metadata:
417420
return f.metadata["object_hook"](v)
418-
elif isclass(f.type) and issubclass(f.type, CBORSerializable):
419-
return f.type.from_primitive(v)
420-
elif hasattr(f.type, "__origin__") and (f.type.__origin__ is list):
421-
t_args = f.type.__args__
421+
return _restore_typed_primitive(f.type, v)
422+
423+
424+
def _restore_typed_primitive(
425+
t: typing.Type, v: Primitive
426+
) -> Union[Primitive, CBORSerializable]:
427+
"""Try to restore a value back to its original type based on information given in field.
428+
429+
Args:
430+
f (type): A type
431+
v (:const:`Primitive`): A CBOR primitive.
432+
433+
Returns:
434+
Union[:const:`Primitive`, CBORSerializable]: A CBOR primitive or a CBORSerializable.
435+
"""
436+
if t in PRIMITIVE_TYPES and isinstance(v, t):
437+
return v
438+
elif isclass(t) and issubclass(t, CBORSerializable):
439+
return t.from_primitive(v)
440+
elif hasattr(t, "__origin__") and (t.__origin__ is list):
441+
t_args = t.__args__
422442
if len(t_args) != 1:
423443
raise DeserializeException(
424444
f"List types need exactly one type argument, but got {t_args}"
425445
)
426446
t = t_args[0]
427447
if not isinstance(v, list):
428448
raise DeserializeException(f"Expected type list but got {type(v)}")
429-
if isclass(t) and issubclass(t, CBORSerializable):
430-
return IndefiniteList([t.from_primitive(w) for w in v])
431-
else:
432-
return IndefiniteList(v)
433-
elif isclass(f.type) and issubclass(f.type, IndefiniteList):
449+
return IndefiniteList([_restore_typed_primitive(t, w) for w in v])
450+
elif isclass(t) and issubclass(t, IndefiniteList):
434451
return IndefiniteList(v)
435-
elif hasattr(f.type, "__origin__") and (f.type.__origin__ is dict):
436-
t_args = f.type.__args__
452+
elif hasattr(t, "__origin__") and (t.__origin__ is dict):
453+
t_args = t.__args__
437454
if len(t_args) != 2:
438455
raise DeserializeException(
439456
f"Dict types need exactly two type arguments, but got {t_args}"
440457
)
441458
key_t = t_args[0]
442459
val_t = t_args[1]
443-
if isclass(key_t) and issubclass(key_t, CBORSerializable):
444-
key_converter = key_t.from_primitive
445-
else:
446-
key_converter = _identity
447-
if isclass(val_t) and issubclass(val_t, CBORSerializable):
448-
val_converter = val_t.from_primitive
449-
else:
450-
val_converter = _identity
451460
if not isinstance(v, dict):
452461
raise DeserializeException(f"Expected dict type but got {type(v)}")
453-
return {key_converter(key): val_converter(val) for key, val in v.items()}
454-
elif hasattr(f.type, "__origin__") and (
455-
f.type.__origin__ is Union or f.type.__origin__ is Optional
462+
return {
463+
_restore_typed_primitive(key_t, key): _restore_typed_primitive(val_t, val)
464+
for key, val in v.items()
465+
}
466+
elif hasattr(t, "__origin__") and (
467+
t.__origin__ is Union or t.__origin__ is Optional
456468
):
457-
t_args = f.type.__args__
469+
t_args = t.__args__
458470
for t in t_args:
459-
if isclass(t) and issubclass(t, IndefiniteList):
460-
return IndefiniteList(v)
461-
elif isclass(t) and issubclass(t, CBORSerializable):
462-
try:
463-
return t.from_primitive(v)
464-
except DeserializeException:
465-
pass
466-
else:
467-
if not isclass(t) and hasattr(t, "__origin__"):
468-
t = t.__origin__
469-
if t in PRIMITIVE_TYPES and isinstance(v, t):
470-
return v
471+
try:
472+
return _restore_typed_primitive(t, v)
473+
except DeserializeException:
474+
pass
471475
raise DeserializeException(
472476
f"Cannot deserialize object: \n{v}\n in any valid type from {t_args}."
473477
)
474-
return v
478+
raise DeserializeException(f"Cannot deserialize object: \n{v}\n to type {t}.")
475479

476480

477481
ArrayBase = TypeVar("ArrayBase", bound="ArrayCBORSerializable")
@@ -556,8 +560,8 @@ def to_shallow_primitive(self) -> List[Primitive]:
556560
return primitives
557561

558562
@classmethod
559-
@limit_primitive_type(list)
560-
def from_primitive(cls: Type[ArrayBase], values: list) -> ArrayBase:
563+
@limit_primitive_type(list, tuple)
564+
def from_primitive(cls: Type[ArrayBase], values: Union[list, tuple]) -> ArrayBase:
561565
"""Restore a primitive value to its original class type.
562566
563567
Args:

test/pycardano/test_serialization.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Union
2+
13
from dataclasses import dataclass, field
24
from test.pycardano.util import check_two_way_cbor
35

@@ -40,7 +42,7 @@ def test_array_cbor_serializable():
4042
@dataclass
4143
class Test1(ArrayCBORSerializable):
4244
a: str
43-
b: str = None
45+
b: Union[str, None] = None
4446

4547
@dataclass
4648
class Test2(ArrayCBORSerializable):
@@ -87,7 +89,7 @@ class Test1(MapCBORSerializable):
8789

8890
@dataclass
8991
class Test2(MapCBORSerializable):
90-
c: str = None
92+
c: Union[str, None] = None
9193
test1: Test1 = field(default_factory=Test1)
9294

9395
t = Test2(test1=Test1(a="a"))

0 commit comments

Comments
 (0)