From 17bde11bb59d84e4bee642bab827de10b9f8f5a4 Mon Sep 17 00:00:00 2001 From: Jerry Date: Sat, 15 Apr 2023 21:41:55 -0700 Subject: [PATCH 1/2] Validate field type on serializing Validate each field in CBORSerializable during serialization. This will avoid bugs caused by incorrect type of field. For instance, when a field should be Address type, but being a string type instead, it would've been serialized successfully but failed to be deserialized. --- pycardano/serialization.py | 69 ++++++++++++++++++---- pycardano/transaction.py | 1 + test/pycardano/test_certificate.py | 5 +- test/pycardano/test_serialization.py | 86 +++++++++++++++++++++++++++- 4 files changed, 147 insertions(+), 14 deletions(-) diff --git a/pycardano/serialization.py b/pycardano/serialization.py index 79b1473d..34456a6c 100644 --- a/pycardano/serialization.py +++ b/pycardano/serialization.py @@ -11,7 +11,18 @@ from decimal import Decimal from functools import wraps from inspect import isclass -from typing import Any, Callable, List, Optional, Type, TypeVar, Union, get_type_hints +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, + get_type_hints, +) from cbor2 import CBOREncoder, CBORSimpleValue, CBORTag, dumps, loads, undefined from pprintpp import pformat @@ -254,7 +265,45 @@ def validate(self): Raises: InvalidDataException: When the data is invalid. """ - pass + type_hints = get_type_hints(self.__class__) + + def _check_recursive(value, type_hint): + if type_hint is Any: + return True + origin = getattr(type_hint, "__origin__", None) + if origin is None: + if isinstance(value, CBORSerializable): + value.validate() + return isinstance(value, type_hint) + elif origin is ClassVar: + return _check_recursive(value, type_hint.__args__[0]) + elif origin is Union: + return any(_check_recursive(value, arg) for arg in type_hint.__args__) + elif origin is Dict or isinstance(value, dict): + key_type, value_type = type_hint.__args__ + return all( + _check_recursive(k, key_type) and _check_recursive(v, value_type) + for k, v in value.items() + ) + elif origin in (list, set, tuple): + if value is None: + return True + args = type_hint.__args__ + if len(args) == 1: + return all(_check_recursive(item, args[0]) for item in value) + elif len(args) > 1: + return all( + _check_recursive(item, arg) for item, arg in zip(value, args) + ) + return True # We don't know how to check this type + + for field_name, field_type in type_hints.items(): + field_value = getattr(self, field_name) + if not _check_recursive(field_value, field_type): + raise TypeError( + f"Field '{field_name}' should be of type {field_type}, " + f"got {type(field_value)} instead." + ) def to_validated_primitive(self) -> Primitive: """Convert the instance and its elements to CBOR primitives recursively with data validated by :meth:`validate` @@ -505,8 +554,8 @@ class ArrayCBORSerializable(CBORSerializable): >>> t = Test2(c="c", test1=Test1(a="a")) >>> t Test2(c='c', test1=Test1(a='a', b=None)) - >>> cbor_hex = t.to_cbor() - >>> cbor_hex + >>> cbor_hex = t.to_cbor() # doctest: +SKIP + >>> cbor_hex # doctest: +SKIP '826163826161f6' >>> Test2.from_cbor(cbor_hex) # doctest: +SKIP Test2(c='c', test1=Test1(a='a', b=None)) @@ -534,8 +583,8 @@ class ArrayCBORSerializable(CBORSerializable): Test2(c='c', test1=Test1(a='a', b=None)) >>> t.to_primitive() # Notice below that attribute "b" is not included in converted primitive. ['c', ['a']] - >>> cbor_hex = t.to_cbor() - >>> cbor_hex + >>> cbor_hex = t.to_cbor() # doctest: +SKIP + >>> cbor_hex # doctest: +SKIP '826163816161' >>> Test2.from_cbor(cbor_hex) # doctest: +SKIP Test2(c='c', test1=Test1(a='a', b=None)) @@ -621,8 +670,8 @@ class MapCBORSerializable(CBORSerializable): Test2(c=None, test1=Test1(a='a', b='')) >>> t.to_primitive() {'c': None, 'test1': {'a': 'a', 'b': ''}} - >>> cbor_hex = t.to_cbor() - >>> cbor_hex + >>> cbor_hex = t.to_cbor() # doctest: +SKIP + >>> cbor_hex # doctest: +SKIP 'a26163f6657465737431a261616161616260' >>> Test2.from_cbor(cbor_hex) # doctest: +SKIP Test2(c=None, test1=Test1(a='a', b='')) @@ -645,8 +694,8 @@ class MapCBORSerializable(CBORSerializable): Test2(c=None, test1=Test1(a='a', b='')) >>> t.to_primitive() {'1': {'0': 'a', '1': ''}} - >>> cbor_hex = t.to_cbor() - >>> cbor_hex + >>> cbor_hex = t.to_cbor() # doctest: +SKIP + >>> cbor_hex # doctest: +SKIP 'a16131a261306161613160' >>> Test2.from_cbor(cbor_hex) # doctest: +SKIP Test2(c=None, test1=Test1(a='a', b='')) diff --git a/pycardano/transaction.py b/pycardano/transaction.py index bc564eae..56eacbe1 100644 --- a/pycardano/transaction.py +++ b/pycardano/transaction.py @@ -390,6 +390,7 @@ def __post_init__(self): self.amount = Value(self.amount) def validate(self): + super().validate() if isinstance(self.amount, int) and self.amount < 0: raise InvalidDataException( f"Transaction output cannot have negative amount of ADA or " diff --git a/test/pycardano/test_certificate.py b/test/pycardano/test_certificate.py index 8eb25bac..7786010c 100644 --- a/test/pycardano/test_certificate.py +++ b/test/pycardano/test_certificate.py @@ -1,5 +1,6 @@ from pycardano.address import Address from pycardano.certificate import ( + PoolKeyHash, StakeCredential, StakeDelegation, StakeDeregistration, @@ -43,7 +44,9 @@ def test_stake_deregistration(): def test_stake_delegation(): stake_credential = StakeCredential(TEST_ADDR.staking_part) - stake_delegation = StakeDelegation(stake_credential, b"1" * POOL_KEY_HASH_SIZE) + stake_delegation = StakeDelegation( + stake_credential, PoolKeyHash(b"1" * POOL_KEY_HASH_SIZE) + ) assert ( stake_delegation.to_cbor() diff --git a/test/pycardano/test_serialization.py b/test/pycardano/test_serialization.py index 2622b626..6df869b7 100644 --- a/test/pycardano/test_serialization.py +++ b/test/pycardano/test_serialization.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from test.pycardano.util import check_two_way_cbor -from typing import Any, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import pytest @@ -67,7 +67,7 @@ def test_array_cbor_serializable_optional_field(): @dataclass class Test1(ArrayCBORSerializable): a: str - b: str = field(default=None, metadata={"optional": True}) + b: Optional[str] = field(default=None, metadata={"optional": True}) @dataclass class Test2(ArrayCBORSerializable): @@ -104,7 +104,7 @@ class Test1(MapCBORSerializable): @dataclass class Test2(MapCBORSerializable): - c: str = field(default=None, metadata={"key": "0", "optional": True}) + c: Optional[str] = field(default=None, metadata={"key": "0", "optional": True}) test1: Test1 = field(default_factory=Test1, metadata={"key": "1"}) t = Test2(test1=Test1(a="a")) @@ -172,3 +172,83 @@ class Test1(MapCBORSerializable): t = Test1(a="a", b=1) check_two_way_cbor(t) + + +def test_wrong_primitive_type(): + @dataclass + class Test1(MapCBORSerializable): + a: str = "" + + with pytest.raises(TypeError): + Test1(a=1).to_cbor() + + +def test_wrong_union_type(): + @dataclass + class Test1(MapCBORSerializable): + a: Union[str, int] = "" + + with pytest.raises(TypeError): + Test1(a=1.0).to_cbor() + + +def test_wrong_optional_type(): + @dataclass + class Test1(MapCBORSerializable): + a: Optional[str] = "" + + with pytest.raises(TypeError): + Test1(a=1.0).to_cbor() + + +def test_wrong_list_type(): + @dataclass + class Test1(MapCBORSerializable): + a: List[str] = "" + + with pytest.raises(TypeError): + Test1(a=[1]).to_cbor() + + +def test_wrong_dict_type(): + @dataclass + class Test1(MapCBORSerializable): + a: Dict[str, int] = "" + + with pytest.raises(TypeError): + Test1(a={1: 1}).to_cbor() + + +def test_wrong_tuple_type(): + @dataclass + class Test1(MapCBORSerializable): + a: Tuple[str, int] = "" + + with pytest.raises(TypeError): + Test1(a=(1, 1)).to_cbor() + + +def test_wrong_set_type(): + @dataclass + class Test1(MapCBORSerializable): + a: Set[str] = "" + + with pytest.raises(TypeError): + Test1(a={1}).to_cbor() + + +def test_wrong_nested_type(): + @dataclass + class Test1(MapCBORSerializable): + a: str = "" + + @dataclass + class Test2(MapCBORSerializable): + a: Test1 = "" + b: Optional[Test1] = None + + with pytest.raises(TypeError): + Test2(a=1).to_cbor() + + with pytest.raises(TypeError): + Test2(a=Test1(a=1)).to_cbor() From b3bccb6d72727cd3328a341e08ac82994735ded8 Mon Sep 17 00:00:00 2001 From: Jerry Date: Sun, 16 Apr 2023 13:43:57 -0700 Subject: [PATCH 2/2] Update pycardano/serialization.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Print violating field directly Co-authored-by: Niels Mündler --- pycardano/serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pycardano/serialization.py b/pycardano/serialization.py index 34456a6c..8d8d35e2 100644 --- a/pycardano/serialization.py +++ b/pycardano/serialization.py @@ -302,7 +302,7 @@ def _check_recursive(value, type_hint): if not _check_recursive(field_value, field_type): raise TypeError( f"Field '{field_name}' should be of type {field_type}, " - f"got {type(field_value)} instead." + f"got {repr(field_value)} instead." ) def to_validated_primitive(self) -> Primitive: