diff --git a/pycardano/plutus.py b/pycardano/plutus.py index d0feed08..6e6cab5b 100644 --- a/pycardano/plutus.py +++ b/pycardano/plutus.py @@ -589,6 +589,35 @@ def _dfs(obj): raise DeserializeException( f"Unexpected data structure: {f}." ) + elif ( + hasattr(f_info.type, "__origin__") + and f_info.type.__origin__ is dict + ): + t_args = f_info.type.__args__ + if len(t_args) != 2: + raise DeserializeException( + "Dict type with wrong number of arguments" + ) + if "map" not in f: + raise DeserializeException( + f'Expected type "map" in object but got "{f}"' + ) + key_t = t_args[0] + val_t = t_args[1] + if inspect.isclass(key_t) and issubclass(key_t, PlutusData): + key_convert = key_t.from_dict + else: + key_convert = _dfs + if inspect.isclass(val_t) and issubclass(val_t, PlutusData): + val_convert = val_t.from_dict + else: + val_convert = _dfs + converted_fields.append( + { + key_convert(pair["k"]): val_convert(pair["v"]) + for pair in f["map"] + } + ) else: converted_fields.append(_dfs(f)) return cls(*converted_fields) diff --git a/pycardano/serialization.py b/pycardano/serialization.py index e4dcc32f..4e226fa3 100644 --- a/pycardano/serialization.py +++ b/pycardano/serialization.py @@ -37,6 +37,10 @@ ] +def _identity(x): + return x + + class IndefiniteList(UserList): def __init__(self, li: Primitive): # type: ignore super().__init__(li) # type: ignore @@ -415,6 +419,25 @@ def _restore_dataclass_field( return f.type.from_primitive(v) elif isclass(f.type) and issubclass(f.type, IndefiniteList): return IndefiniteList(v) + elif hasattr(f.type, "__origin__") and (f.type.__origin__ is dict): + t_args = f.type.__args__ + if len(t_args) != 2: + raise DeserializeException( + f"Dict types need exactly two type arguments, but got {t_args}" + ) + key_t = t_args[0] + val_t = t_args[1] + if isclass(key_t) and issubclass(key_t, CBORSerializable): + key_converter = key_t.from_primitive + else: + key_converter = _identity + if isclass(val_t) and issubclass(val_t, CBORSerializable): + val_converter = val_t.from_primitive + else: + val_converter = _identity + if not isinstance(v, dict): + raise DeserializeException(f"Expected dict type but got {type(v)}") + return {key_converter(key): val_converter(val) for key, val in v.items()} elif hasattr(f.type, "__origin__") and ( f.type.__origin__ is Union or f.type.__origin__ is Optional ): diff --git a/test/pycardano/test_plutus.py b/test/pycardano/test_plutus.py index 8e0fdcf1..600f45c5 100644 --- a/test/pycardano/test_plutus.py +++ b/test/pycardano/test_plutus.py @@ -1,6 +1,8 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass +import unittest + from test.pycardano.util import check_two_way_cbor -from typing import Union, Optional +from typing import Union, Dict import pytest @@ -39,6 +41,13 @@ class LargestTest(PlutusData): CONSTR_ID = 9 +@dataclass +class DictTest(PlutusData): + CONSTR_ID = 3 + + a: Dict[int, LargestTest] + + @dataclass class VestingParam(PlutusData): CONSTR_ID = 1 @@ -95,6 +104,29 @@ def test_plutus_data_json(): assert my_vesting == VestingParam.from_json(encoded_json) +def test_plutus_data_json_dict(): + test = DictTest({0: LargestTest(), 1: LargestTest()}) + + encoded_json = test.to_json(separators=(",", ":")) + + assert ( + '{"constructor":3,"fields":[{"map":[{"v":{"constructor":9,"fields":[]},"k":{"int":0}},{"v":{"constructor":9,"fields":[]},"k":{"int":1}}]}]}' + == encoded_json + ) + + assert test == DictTest.from_json(encoded_json) + + +def test_plutus_data_cbor_dict(): + test = DictTest({0: LargestTest(), 1: LargestTest()}) + + encoded_cbor = test.to_cbor() + + assert "d87c9fa200d905028001d9050280ff" == encoded_cbor + + assert test == DictTest.from_cbor(encoded_cbor) + + def test_plutus_data_to_json_wrong_type(): test = MyTest(123, b"1234", IndefiniteList([4, 5, 6]), {1: b"1", 2: b"2"}) test.a = "123"