Skip to content

Commit 8d4a50d

Browse files
Added IndefiniteDecoder for round trip plutusdata serialization
1 parent a312c08 commit 8d4a50d

File tree

2 files changed

+52
-6
lines changed

2 files changed

+52
-6
lines changed

pycardano/serialization.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from io import BytesIO
56
import re
67
import typing
78
from collections import OrderedDict, UserList, defaultdict
@@ -21,6 +22,7 @@
2122
Iterable,
2223
List,
2324
Optional,
25+
Sequence,
2426
Set,
2527
Type,
2628
TypeVar,
@@ -41,6 +43,7 @@
4143
pass
4244

4345
from cbor2 import (
46+
CBORDecoder,
4447
CBOREncoder,
4548
CBORSimpleValue,
4649
CBORTag,
@@ -199,6 +202,17 @@ def wrapper(cls, value: Primitive):
199202
CBORBase = TypeVar("CBORBase", bound="CBORSerializable")
200203

201204

205+
class IndefiniteDecoder(CBORDecoder):
206+
def decode_array(self, subtype: int) -> Sequence[Any]:
207+
# Major tag 4
208+
length = self._decode_length(subtype, allow_indefinite=True)
209+
210+
if length is None:
211+
return IndefiniteList(super().decode_array(subtype=subtype))
212+
else:
213+
return super().decode_array(subtype=subtype)
214+
215+
202216
def default_encoder(
203217
encoder: CBOREncoder, value: Union[CBORSerializable, IndefiniteList]
204218
):
@@ -265,7 +279,7 @@ class CBORSerializable:
265279
does not refer to itself, which could cause infinite loops.
266280
"""
267281

268-
def to_shallow_primitive(self) -> Primitive:
282+
def to_shallow_primitive(self) -> Union[Primitive, CBORSerializable]:
269283
"""
270284
Convert the instance to a CBOR primitive. If the primitive is a container, e.g. list, dict, the type of
271285
its elements could be either a Primitive or a CBORSerializable.
@@ -516,7 +530,10 @@ def from_cbor(cls, payload: Union[str, bytes]) -> CBORSerializable:
516530
"""
517531
if type(payload) is str:
518532
payload = bytes.fromhex(payload)
519-
value = loads(payload) # type: ignore
533+
534+
with BytesIO(payload) as fp:
535+
value = IndefiniteDecoder(fp).decode()
536+
520537
return cls.from_primitive(value)
521538

522539
def __repr__(self):
@@ -580,10 +597,14 @@ def _restore_typed_primitive(
580597
raise DeserializeException(
581598
f"List types need exactly one type argument, but got {t_args}"
582599
)
583-
t = t_args[0]
584-
if not isinstance(v, list):
600+
t_subtype = t_args[0]
601+
if not isinstance(v, (list, IndefiniteList)):
585602
raise DeserializeException(f"Expected type list but got {type(v)}")
586-
return IndefiniteList([_restore_typed_primitive(t, w) for w in v])
603+
v_list = [_restore_typed_primitive(t_subtype, w) for w in v]
604+
if t == IndefiniteList:
605+
return IndefiniteList(v_list)
606+
else:
607+
return v_list
587608
elif isclass(t) and t == ByteString:
588609
if not isinstance(v, bytes):
589610
raise DeserializeException(f"Expected type bytes but got {type(v)}")

test/pycardano/test_serialization.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
VerificationKeyWitness,
3131
)
3232
from pycardano.exception import DeserializeException, SerializeException
33-
from pycardano.plutus import PlutusV1Script, PlutusV2Script
33+
from pycardano.plutus import PlutusData, PlutusV1Script, PlutusV2Script
3434
from pycardano.serialization import (
3535
ArrayCBORSerializable,
3636
ByteString,
@@ -368,6 +368,31 @@ class Test1(CBORSerializable):
368368
obj.validate()
369369

370370

371+
@pytest.mark.xfail
372+
def test_datum_raw_round_trip():
373+
@dataclass
374+
class TestDatum(PlutusData):
375+
CONSTR_ID = 0
376+
a: int
377+
b: List[bytes]
378+
379+
datum = TestDatum(a=1, b=[b"test", b"datum"])
380+
restored = RawPlutusData.from_cbor(datum.to_cbor())
381+
assert datum.to_cbor_hex() == restored.to_cbor_hex()
382+
383+
384+
def test_datum_round_trip():
385+
@dataclass
386+
class TestDatum(PlutusData):
387+
CONSTR_ID = 0
388+
a: int
389+
b: List[bytes]
390+
391+
datum = TestDatum(a=1, b=[b"test", b"datum"])
392+
restored = TestDatum.from_cbor(datum.to_cbor())
393+
assert datum.to_cbor_hex() == restored.to_cbor_hex()
394+
395+
371396
def test_wrong_primitive_type():
372397
@dataclass
373398
class Test1(MapCBORSerializable):

0 commit comments

Comments
 (0)