Skip to content

Commit 6ec8623

Browse files
authored
Merge pull request #208 from cffls/main
Validate field type on serializing
2 parents 95881a0 + b3bccb6 commit 6ec8623

File tree

4 files changed

+147
-14
lines changed

4 files changed

+147
-14
lines changed

pycardano/serialization.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,18 @@
1111
from decimal import Decimal
1212
from functools import wraps
1313
from inspect import isclass
14-
from typing import Any, Callable, List, Optional, Type, TypeVar, Union, get_type_hints
14+
from typing import (
15+
Any,
16+
Callable,
17+
ClassVar,
18+
Dict,
19+
List,
20+
Optional,
21+
Type,
22+
TypeVar,
23+
Union,
24+
get_type_hints,
25+
)
1526

1627
from cbor2 import CBOREncoder, CBORSimpleValue, CBORTag, dumps, loads, undefined
1728
from pprintpp import pformat
@@ -254,7 +265,45 @@ def validate(self):
254265
Raises:
255266
InvalidDataException: When the data is invalid.
256267
"""
257-
pass
268+
type_hints = get_type_hints(self.__class__)
269+
270+
def _check_recursive(value, type_hint):
271+
if type_hint is Any:
272+
return True
273+
origin = getattr(type_hint, "__origin__", None)
274+
if origin is None:
275+
if isinstance(value, CBORSerializable):
276+
value.validate()
277+
return isinstance(value, type_hint)
278+
elif origin is ClassVar:
279+
return _check_recursive(value, type_hint.__args__[0])
280+
elif origin is Union:
281+
return any(_check_recursive(value, arg) for arg in type_hint.__args__)
282+
elif origin is Dict or isinstance(value, dict):
283+
key_type, value_type = type_hint.__args__
284+
return all(
285+
_check_recursive(k, key_type) and _check_recursive(v, value_type)
286+
for k, v in value.items()
287+
)
288+
elif origin in (list, set, tuple):
289+
if value is None:
290+
return True
291+
args = type_hint.__args__
292+
if len(args) == 1:
293+
return all(_check_recursive(item, args[0]) for item in value)
294+
elif len(args) > 1:
295+
return all(
296+
_check_recursive(item, arg) for item, arg in zip(value, args)
297+
)
298+
return True # We don't know how to check this type
299+
300+
for field_name, field_type in type_hints.items():
301+
field_value = getattr(self, field_name)
302+
if not _check_recursive(field_value, field_type):
303+
raise TypeError(
304+
f"Field '{field_name}' should be of type {field_type}, "
305+
f"got {repr(field_value)} instead."
306+
)
258307

259308
def to_validated_primitive(self) -> Primitive:
260309
"""Convert the instance and its elements to CBOR primitives recursively with data validated by :meth:`validate`
@@ -505,8 +554,8 @@ class ArrayCBORSerializable(CBORSerializable):
505554
>>> t = Test2(c="c", test1=Test1(a="a"))
506555
>>> t
507556
Test2(c='c', test1=Test1(a='a', b=None))
508-
>>> cbor_hex = t.to_cbor()
509-
>>> cbor_hex
557+
>>> cbor_hex = t.to_cbor() # doctest: +SKIP
558+
>>> cbor_hex # doctest: +SKIP
510559
'826163826161f6'
511560
>>> Test2.from_cbor(cbor_hex) # doctest: +SKIP
512561
Test2(c='c', test1=Test1(a='a', b=None))
@@ -534,8 +583,8 @@ class ArrayCBORSerializable(CBORSerializable):
534583
Test2(c='c', test1=Test1(a='a', b=None))
535584
>>> t.to_primitive() # Notice below that attribute "b" is not included in converted primitive.
536585
['c', ['a']]
537-
>>> cbor_hex = t.to_cbor()
538-
>>> cbor_hex
586+
>>> cbor_hex = t.to_cbor() # doctest: +SKIP
587+
>>> cbor_hex # doctest: +SKIP
539588
'826163816161'
540589
>>> Test2.from_cbor(cbor_hex) # doctest: +SKIP
541590
Test2(c='c', test1=Test1(a='a', b=None))
@@ -621,8 +670,8 @@ class MapCBORSerializable(CBORSerializable):
621670
Test2(c=None, test1=Test1(a='a', b=''))
622671
>>> t.to_primitive()
623672
{'c': None, 'test1': {'a': 'a', 'b': ''}}
624-
>>> cbor_hex = t.to_cbor()
625-
>>> cbor_hex
673+
>>> cbor_hex = t.to_cbor() # doctest: +SKIP
674+
>>> cbor_hex # doctest: +SKIP
626675
'a26163f6657465737431a261616161616260'
627676
>>> Test2.from_cbor(cbor_hex) # doctest: +SKIP
628677
Test2(c=None, test1=Test1(a='a', b=''))
@@ -645,8 +694,8 @@ class MapCBORSerializable(CBORSerializable):
645694
Test2(c=None, test1=Test1(a='a', b=''))
646695
>>> t.to_primitive()
647696
{'1': {'0': 'a', '1': ''}}
648-
>>> cbor_hex = t.to_cbor()
649-
>>> cbor_hex
697+
>>> cbor_hex = t.to_cbor() # doctest: +SKIP
698+
>>> cbor_hex # doctest: +SKIP
650699
'a16131a261306161613160'
651700
>>> Test2.from_cbor(cbor_hex) # doctest: +SKIP
652701
Test2(c=None, test1=Test1(a='a', b=''))

pycardano/transaction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ def __post_init__(self):
390390
self.amount = Value(self.amount)
391391

392392
def validate(self):
393+
super().validate()
393394
if isinstance(self.amount, int) and self.amount < 0:
394395
raise InvalidDataException(
395396
f"Transaction output cannot have negative amount of ADA or "

test/pycardano/test_certificate.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pycardano.address import Address
22
from pycardano.certificate import (
3+
PoolKeyHash,
34
StakeCredential,
45
StakeDelegation,
56
StakeDeregistration,
@@ -43,7 +44,9 @@ def test_stake_deregistration():
4344

4445
def test_stake_delegation():
4546
stake_credential = StakeCredential(TEST_ADDR.staking_part)
46-
stake_delegation = StakeDelegation(stake_credential, b"1" * POOL_KEY_HASH_SIZE)
47+
stake_delegation = StakeDelegation(
48+
stake_credential, PoolKeyHash(b"1" * POOL_KEY_HASH_SIZE)
49+
)
4750

4851
assert (
4952
stake_delegation.to_cbor()

test/pycardano/test_serialization.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass, field
22
from test.pycardano.util import check_two_way_cbor
3-
from typing import Any, Union
3+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
44

55
import pytest
66

@@ -67,7 +67,7 @@ def test_array_cbor_serializable_optional_field():
6767
@dataclass
6868
class Test1(ArrayCBORSerializable):
6969
a: str
70-
b: str = field(default=None, metadata={"optional": True})
70+
b: Optional[str] = field(default=None, metadata={"optional": True})
7171

7272
@dataclass
7373
class Test2(ArrayCBORSerializable):
@@ -104,7 +104,7 @@ class Test1(MapCBORSerializable):
104104

105105
@dataclass
106106
class Test2(MapCBORSerializable):
107-
c: str = field(default=None, metadata={"key": "0", "optional": True})
107+
c: Optional[str] = field(default=None, metadata={"key": "0", "optional": True})
108108
test1: Test1 = field(default_factory=Test1, metadata={"key": "1"})
109109

110110
t = Test2(test1=Test1(a="a"))
@@ -172,3 +172,83 @@ class Test1(MapCBORSerializable):
172172
t = Test1(a="a", b=1)
173173

174174
check_two_way_cbor(t)
175+
176+
177+
def test_wrong_primitive_type():
178+
@dataclass
179+
class Test1(MapCBORSerializable):
180+
a: str = ""
181+
182+
with pytest.raises(TypeError):
183+
Test1(a=1).to_cbor()
184+
185+
186+
def test_wrong_union_type():
187+
@dataclass
188+
class Test1(MapCBORSerializable):
189+
a: Union[str, int] = ""
190+
191+
with pytest.raises(TypeError):
192+
Test1(a=1.0).to_cbor()
193+
194+
195+
def test_wrong_optional_type():
196+
@dataclass
197+
class Test1(MapCBORSerializable):
198+
a: Optional[str] = ""
199+
200+
with pytest.raises(TypeError):
201+
Test1(a=1.0).to_cbor()
202+
203+
204+
def test_wrong_list_type():
205+
@dataclass
206+
class Test1(MapCBORSerializable):
207+
a: List[str] = ""
208+
209+
with pytest.raises(TypeError):
210+
Test1(a=[1]).to_cbor()
211+
212+
213+
def test_wrong_dict_type():
214+
@dataclass
215+
class Test1(MapCBORSerializable):
216+
a: Dict[str, int] = ""
217+
218+
with pytest.raises(TypeError):
219+
Test1(a={1: 1}).to_cbor()
220+
221+
222+
def test_wrong_tuple_type():
223+
@dataclass
224+
class Test1(MapCBORSerializable):
225+
a: Tuple[str, int] = ""
226+
227+
with pytest.raises(TypeError):
228+
Test1(a=(1, 1)).to_cbor()
229+
230+
231+
def test_wrong_set_type():
232+
@dataclass
233+
class Test1(MapCBORSerializable):
234+
a: Set[str] = ""
235+
236+
with pytest.raises(TypeError):
237+
Test1(a={1}).to_cbor()
238+
239+
240+
def test_wrong_nested_type():
241+
@dataclass
242+
class Test1(MapCBORSerializable):
243+
a: str = ""
244+
245+
@dataclass
246+
class Test2(MapCBORSerializable):
247+
a: Test1 = ""
248+
b: Optional[Test1] = None
249+
250+
with pytest.raises(TypeError):
251+
Test2(a=1).to_cbor()
252+
253+
with pytest.raises(TypeError):
254+
Test2(a=Test1(a=1)).to_cbor()

0 commit comments

Comments
 (0)