Skip to content

Commit c539d40

Browse files
Improve address type hint (#130)
* UPDATE. including address.py in mypy test * UPDATE. disable mypy str-bytes-safe check to allow bytes to be formatted automatically within a f-string * UPDATE. ensure only bytes values are evaluated for PointerAddress.from_primitive() * UPDATE. only allow str/bytes value for Address.from_primitive() * UPDATE. generalizing from_primitive() base method even more UPDATE. an internal use only limit_primitive_type() helper decorator is added to reduce repeated code UPDATE. specifying exact input type for child from_primitive()
1 parent 5340be1 commit c539d40

11 files changed

+119
-50
lines changed

pycardano/address.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from pycardano.hash import VERIFICATION_KEY_HASH_SIZE, ScriptHash, VerificationKeyHash
2121
from pycardano.network import Network
22-
from pycardano.serialization import CBORSerializable
22+
from pycardano.serialization import CBORSerializable, limit_primitive_type
2323

2424
__all__ = ["AddressType", "PointerAddress", "Address"]
2525

@@ -160,6 +160,7 @@ def to_primitive(self) -> bytes:
160160
return self.encode()
161161

162162
@classmethod
163+
@limit_primitive_type(bytes)
163164
def from_primitive(cls: Type[PointerAddress], value: bytes) -> PointerAddress:
164165
return cls.decode(value)
165166

@@ -339,6 +340,7 @@ def to_primitive(self) -> bytes:
339340
return bytes(self)
340341

341342
@classmethod
343+
@limit_primitive_type(bytes, str)
342344
def from_primitive(cls: Type[Address], value: Union[bytes, str]) -> Address:
343345
if isinstance(value, str):
344346
value = bytes(decode(value))

pycardano/hash.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Type, TypeVar, Union
44

5-
from pycardano.serialization import CBORSerializable
5+
from pycardano.serialization import CBORSerializable, limit_primitive_type
66

77
__all__ = [
88
"VERIFICATION_KEY_HASH_SIZE",
@@ -67,6 +67,7 @@ def to_primitive(self) -> bytes:
6767
return self.payload
6868

6969
@classmethod
70+
@limit_primitive_type(bytes, str)
7071
def from_primitive(cls: Type[T], value: Union[bytes, str]) -> T:
7172
if isinstance(value, str):
7273
value = bytes.fromhex(value)

pycardano/key.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pycardano.crypto.bip32 import BIP32ED25519PrivateKey, HDWallet
1515
from pycardano.exception import InvalidKeyTypeException
1616
from pycardano.hash import VERIFICATION_KEY_HASH_SIZE, VerificationKeyHash
17-
from pycardano.serialization import CBORSerializable
17+
from pycardano.serialization import CBORSerializable, limit_primitive_type
1818

1919
__all__ = [
2020
"Key",
@@ -62,6 +62,7 @@ def to_primitive(self) -> bytes:
6262
return self.payload
6363

6464
@classmethod
65+
@limit_primitive_type(bytes)
6566
def from_primitive(cls: Type["Key"], value: bytes) -> Key:
6667
return cls(value)
6768

pycardano/metadata.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
DictCBORSerializable,
1717
MapCBORSerializable,
1818
Primitive,
19+
limit_primitive_type,
1920
list_hook,
2021
)
2122

@@ -91,6 +92,7 @@ def to_primitive(self) -> Primitive:
9192
return CBORTag(AlonzoMetadata.TAG, super(AlonzoMetadata, self).to_primitive())
9293

9394
@classmethod
95+
@limit_primitive_type(CBORTag)
9496
def from_primitive(cls: Type[AlonzoMetadata], value: CBORTag) -> AlonzoMetadata:
9597
if not hasattr(value, "tag"):
9698
raise DeserializeException(

pycardano/nativescript.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010

1111
from pycardano.exception import DeserializeException
1212
from pycardano.hash import SCRIPT_HASH_SIZE, ScriptHash, VerificationKeyHash
13-
from pycardano.serialization import ArrayCBORSerializable, Primitive, list_hook
13+
from pycardano.serialization import (
14+
ArrayCBORSerializable,
15+
Primitive,
16+
limit_primitive_type,
17+
list_hook,
18+
)
1419
from pycardano.types import JsonDict
1520

1621
__all__ = [
@@ -30,22 +35,12 @@ class NativeScript(ArrayCBORSerializable):
3035
json_field: ClassVar[str]
3136

3237
@classmethod
38+
@limit_primitive_type(list)
3339
def from_primitive(
34-
cls: Type[NativeScript], value: Primitive
40+
cls: Type[NativeScript], value: list
3541
) -> Union[
3642
ScriptPubkey, ScriptAll, ScriptAny, ScriptNofK, InvalidBefore, InvalidHereAfter
3743
]:
38-
if not isinstance(
39-
value,
40-
(
41-
list,
42-
tuple,
43-
),
44-
):
45-
raise DeserializeException(
46-
f"A list or a tuple is required for deserialization: {str(value)}"
47-
)
48-
4944
script_type: int = value[0]
5045
if script_type == ScriptPubkey._TYPE:
5146
return super(NativeScript, ScriptPubkey).from_primitive(value[1:])

pycardano/network.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from enum import Enum
66
from typing import Type
77

8-
from pycardano.exception import DeserializeException
9-
from pycardano.serialization import CBORSerializable, Primitive
8+
from pycardano.serialization import CBORSerializable, limit_primitive_type
109

1110
__all__ = ["Network"]
1211

@@ -23,9 +22,6 @@ def to_primitive(self) -> int:
2322
return self.value
2423

2524
@classmethod
26-
def from_primitive(cls: Type[Network], value: Primitive) -> Network:
27-
if not isinstance(value, int):
28-
raise DeserializeException(
29-
f"An integer value is required for deserialization: {str(value)}"
30-
)
25+
@limit_primitive_type(int)
26+
def from_primitive(cls: Type[Network], value: int) -> Network:
3127
return cls(value)

pycardano/plutus.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77
from dataclasses import dataclass, field, fields
88
from enum import Enum
9-
from typing import Any, ClassVar, List, Optional, Type, Union
9+
from typing import Any, ClassVar, Optional, Type, Union
1010

1111
import cbor2
1212
from cbor2 import CBORTag
@@ -21,9 +21,9 @@
2121
CBORSerializable,
2222
DictCBORSerializable,
2323
IndefiniteList,
24-
Primitive,
2524
RawCBOR,
2625
default_encoder,
26+
limit_primitive_type,
2727
)
2828

2929
__all__ = [
@@ -66,6 +66,7 @@ def to_shallow_primitive(self) -> dict:
6666
return result
6767

6868
@classmethod
69+
@limit_primitive_type(dict)
6970
def from_primitive(cls: Type[CostModels], value: dict) -> CostModels:
7071
raise DeserializeException(
7172
"Deserialization of cost model is impossible, because some information is lost "
@@ -480,11 +481,8 @@ def to_shallow_primitive(self) -> CBORTag:
480481
return CBORTag(102, [self.CONSTR_ID, primitives])
481482

482483
@classmethod
484+
@limit_primitive_type(CBORTag)
483485
def from_primitive(cls: Type[PlutusData], value: CBORTag) -> PlutusData:
484-
if not isinstance(value, CBORTag):
485-
raise DeserializeException(
486-
f"Unexpected type: {CBORTag}. Got {type(value)} instead."
487-
)
488486
if value.tag == 102:
489487
tag = value.value[0]
490488
if tag != cls.CONSTR_ID:
@@ -643,6 +641,7 @@ def _dfs(obj):
643641
return _dfs(self.data)
644642

645643
@classmethod
644+
@limit_primitive_type(CBORTag)
646645
def from_primitive(cls: Type[RawPlutusData], value: CBORTag) -> RawPlutusData:
647646
return cls(value)
648647

@@ -675,6 +674,7 @@ def to_primitive(self) -> int:
675674
return self.value
676675

677676
@classmethod
677+
@limit_primitive_type(int)
678678
def from_primitive(cls: Type[RedeemerTag], value: int) -> RedeemerTag:
679679
return cls(value)
680680

@@ -704,7 +704,8 @@ class Redeemer(ArrayCBORSerializable):
704704
ex_units: ExecutionUnits = None
705705

706706
@classmethod
707-
def from_primitive(cls: Type[Redeemer], values: List[Primitive]) -> Redeemer:
707+
@limit_primitive_type(list)
708+
def from_primitive(cls: Type[Redeemer], values: list) -> Redeemer:
708709
if isinstance(values[2], CBORTag) and cls is Redeemer:
709710
values[2] = RawPlutusData.from_primitive(values[2])
710711
redeemer = super(Redeemer, cls).from_primitive(

pycardano/serialization.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dataclasses import Field, dataclass, fields
99
from datetime import datetime
1010
from decimal import Decimal
11+
from functools import wraps
1112
from inspect import isclass
1213
from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints
1314

@@ -105,6 +106,31 @@ class RawCBOR:
105106
`Cbor2 encoder <https://cbor2.readthedocs.io/en/latest/modules/encoder.html>`_ directly.
106107
"""
107108

109+
110+
def limit_primitive_type(*allowed_types):
111+
"""
112+
A helper function to validate primitive type given to from_primitive class methods
113+
114+
Not exposed to public by intention.
115+
"""
116+
117+
def decorator(func):
118+
@wraps(func)
119+
def wrapper(cls, value: Primitive):
120+
if not isinstance(value, allowed_types):
121+
allowed_types_str = [
122+
allowed_type.__name__ for allowed_type in allowed_types
123+
]
124+
raise DeserializeException(
125+
f"{allowed_types_str} typed value is required for deserialization. Got {type(value)}: {value}"
126+
)
127+
return func(cls, value)
128+
129+
return wrapper
130+
131+
return decorator
132+
133+
108134
CBORBase = TypeVar("CBORBase", bound="CBORSerializable")
109135

110136

@@ -245,7 +271,7 @@ def to_validated_primitive(self) -> Primitive:
245271
return self.to_primitive()
246272

247273
@classmethod
248-
def from_primitive(cls: Type[CBORBase], value: Primitive) -> CBORBase:
274+
def from_primitive(cls: Type[CBORBase], value: Any) -> CBORBase:
249275
"""Turn a CBOR primitive to its original class type.
250276
251277
Args:
@@ -407,7 +433,7 @@ def _restore_dataclass_field(
407433
elif t in PRIMITIVE_TYPES and isinstance(v, t):
408434
return v
409435
raise DeserializeException(
410-
f"Cannot deserialize object: \n{str(v)}\n in any valid type from {t_args}."
436+
f"Cannot deserialize object: \n{v}\n in any valid type from {t_args}."
411437
)
412438
return v
413439

@@ -494,7 +520,8 @@ def to_shallow_primitive(self) -> List[Primitive]:
494520
return primitives
495521

496522
@classmethod
497-
def from_primitive(cls: Type[ArrayBase], values: Primitive) -> ArrayBase:
523+
@limit_primitive_type(list)
524+
def from_primitive(cls: Type[ArrayBase], values: list) -> ArrayBase:
498525
"""Restore a primitive value to its original class type.
499526
500527
Args:
@@ -508,10 +535,6 @@ def from_primitive(cls: Type[ArrayBase], values: Primitive) -> ArrayBase:
508535
DeserializeException: When the object could not be restored from primitives.
509536
"""
510537
all_fields = [f for f in fields(cls) if f.init]
511-
if type(values) != list:
512-
raise DeserializeException(
513-
f"Expect input value to be a list, got a {type(values)} instead."
514-
)
515538

516539
restored_vals = []
517540
type_hints = get_type_hints(cls)
@@ -606,7 +629,8 @@ def to_shallow_primitive(self) -> Primitive:
606629
return primitives
607630

608631
@classmethod
609-
def from_primitive(cls: Type[MapBase], values: Primitive) -> MapBase:
632+
@limit_primitive_type(dict)
633+
def from_primitive(cls: Type[MapBase], values: dict) -> MapBase:
610634
"""Restore a primitive value to its original class type.
611635
612636
Args:
@@ -620,10 +644,6 @@ def from_primitive(cls: Type[MapBase], values: Primitive) -> MapBase:
620644
:class:`pycardano.exception.DeserializeException`: When the object could not be restored from primitives.
621645
"""
622646
all_fields = {f.metadata.get("key", f.name): f for f in fields(cls) if f.init}
623-
if type(values) != dict:
624-
raise DeserializeException(
625-
f"Expect input value to be a dict, got a {type(values)} instead."
626-
)
627647

628648
kwargs = {}
629649
type_hints = get_type_hints(cls)
@@ -725,7 +745,8 @@ def _get_sortable_val(key):
725745
return dict(sorted(self.data.items(), key=lambda x: _get_sortable_val(x[0])))
726746

727747
@classmethod
728-
def from_primitive(cls: Type[DictBase], value: Primitive) -> DictBase:
748+
@limit_primitive_type(dict)
749+
def from_primitive(cls: Type[DictBase], value: dict) -> DictBase:
729750
"""Restore a primitive value to its original class type.
730751
731752
Args:
@@ -739,11 +760,7 @@ def from_primitive(cls: Type[DictBase], value: Primitive) -> DictBase:
739760
DeserializeException: When the object could not be restored from primitives.
740761
"""
741762
if not value:
742-
raise DeserializeException(f"Cannot accept empty value {str(value)}.")
743-
if not isinstance(value, dict):
744-
raise DeserializeException(
745-
f"A dictionary value is required for deserialization: {str(value)}"
746-
)
763+
raise DeserializeException(f"Cannot accept empty value {value}.")
747764

748765
restored = cls()
749766
for k, v in value.items():

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,11 @@ profile = "black"
6363

6464
[tool.mypy]
6565
ignore_missing_imports = true
66+
disable_error_code = ["str-bytes-safe"]
6667
python_version = 3.7
6768
exclude = [
6869
'^pycardano/cip/cip8.py$',
6970
'^pycardano/crypto/bech32.py$',
70-
'^pycardano/address.py$',
7171
'^pycardano/certificate.py$',
7272
'^pycardano/coinselection.py$',
7373
'^pycardano/exception.py$',

test/pycardano/test_address.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from pycardano.address import Address
1+
from unittest import TestCase
2+
3+
from pycardano.address import Address, PointerAddress
4+
from pycardano.exception import DeserializeException
25
from pycardano.key import PaymentVerificationKey
36
from pycardano.network import Network
47

@@ -15,3 +18,27 @@ def test_payment_addr():
1518
Address(vk.hash(), network=Network.TESTNET).encode()
1619
== "addr_test1vr2p8st5t5cxqglyjky7vk98k7jtfhdpvhl4e97cezuhn0cqcexl7"
1720
)
21+
22+
23+
class PointerAddressTest(TestCase):
24+
def test_from_primitive_invalid_value(self):
25+
with self.assertRaises(DeserializeException):
26+
PointerAddress.from_primitive(1)
27+
28+
with self.assertRaises(DeserializeException):
29+
PointerAddress.from_primitive([])
30+
31+
with self.assertRaises(DeserializeException):
32+
PointerAddress.from_primitive({})
33+
34+
35+
class AddressTest(TestCase):
36+
def test_from_primitive_invalid_value(self):
37+
with self.assertRaises(DeserializeException):
38+
Address.from_primitive(1)
39+
40+
with self.assertRaises(DeserializeException):
41+
Address.from_primitive([])
42+
43+
with self.assertRaises(DeserializeException):
44+
Address.from_primitive({})

0 commit comments

Comments
 (0)