Skip to content

Commit d68fb5a

Browse files
committed
Fix static typing
Fixed the following files: pycardano/certificate.py pycardano/coinselection.py pycardano/key.py pycardano/metadata.py pycardano/plutus.py pycardano/serialization.py pycardano/transaction.py pycardano/txbuilder.py pycardano/utils.py pycardano/witness.py
1 parent 5e37db1 commit d68fb5a

12 files changed

+206
-149
lines changed

pycardano/certificate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Union
2+
from typing import Optional, Union
33

44
from pycardano.hash import PoolKeyHash, ScriptHash, VerificationKeyHash
55
from pycardano.serialization import ArrayCBORSerializable
@@ -16,7 +16,7 @@
1616
@dataclass(repr=False)
1717
class StakeCredential(ArrayCBORSerializable):
1818

19-
_CODE: int = field(init=False, default=None)
19+
_CODE: Optional[int] = field(init=False, default=None)
2020

2121
credential: Union[VerificationKeyHash, ScriptHash]
2222

pycardano/coinselection.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def select(
3333
utxos: List[UTxO],
3434
outputs: List[TransactionOutput],
3535
context: ChainContext,
36-
max_input_count: int = None,
37-
include_max_fee: bool = True,
38-
respect_min_utxo: bool = True,
36+
max_input_count: Optional[int] = None,
37+
include_max_fee: Optional[bool] = True,
38+
respect_min_utxo: Optional[bool] = True,
3939
) -> Tuple[List[UTxO], Value]:
4040
"""From an input list of UTxOs, select a subset of UTxOs whose sum (including ADA and multi-assets)
4141
is equal to or larger than the sum of a set of outputs.
@@ -115,7 +115,11 @@ def select(
115115
if change.coin < min_change_amount:
116116
additional, _ = self.select(
117117
available,
118-
[TransactionOutput(None, min_change_amount - change.coin)],
118+
[
119+
TransactionOutput(
120+
_FAKE_ADDR, Value(min_change_amount - change.coin)
121+
)
122+
],
119123
context,
120124
max_input_count - len(selected) if max_input_count else None,
121125
include_max_fee=False,
@@ -230,13 +234,13 @@ def _improve(
230234
remaining: List[UTxO],
231235
ideal: Value,
232236
upper_bound: Value,
233-
max_input_count: int,
237+
max_input_count: Optional[int] = None,
234238
):
235239
if not remaining or self._find_diff_by_former(ideal, selected_amount) <= 0:
236240
# In case where there is no remaining UTxOs or we already selected more than ideal,
237241
# we cannot improve by randomly adding more UTxOs, therefore return immediate.
238242
return
239-
if max_input_count and len(selected) > max_input_count:
243+
if max_input_count is not None and len(selected) > max_input_count:
240244
raise MaxInputCountExceededException(
241245
f"Max input count: {max_input_count} exceeded!"
242246
)
@@ -269,9 +273,9 @@ def select(
269273
utxos: List[UTxO],
270274
outputs: List[TransactionOutput],
271275
context: ChainContext,
272-
max_input_count: int = None,
273-
include_max_fee: bool = True,
274-
respect_min_utxo: bool = True,
276+
max_input_count: Optional[int] = None,
277+
include_max_fee: Optional[bool] = True,
278+
respect_min_utxo: Optional[bool] = True,
275279
) -> Tuple[List[UTxO], Value]:
276280
# Shallow copy the list
277281
remaining = list(utxos)
@@ -284,7 +288,7 @@ def select(
284288
request_sorted = sorted(assets, key=self._get_single_asset_val, reverse=True)
285289

286290
# Phase 1 - random select
287-
selected = []
291+
selected: List[UTxO] = []
288292
selected_amount = Value()
289293
for r in request_sorted:
290294
self._random_select_subset(r, remaining, selected, selected_amount)
@@ -321,7 +325,11 @@ def select(
321325
if change.coin < min_change_amount:
322326
additional, _ = self.select(
323327
remaining,
324-
[TransactionOutput(None, min_change_amount - change.coin)],
328+
[
329+
TransactionOutput(
330+
_FAKE_ADDR, Value(min_change_amount - change.coin)
331+
)
332+
],
325333
context,
326334
max_input_count - len(selected) if max_input_count else None,
327335
include_max_fee=False,

pycardano/key.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import json
66
import os
7-
from typing import Type
7+
from typing import Optional, Type
88

99
from nacl.encoding import RawEncoder
1010
from nacl.hash import blake2b
@@ -41,7 +41,12 @@ class Key(CBORSerializable):
4141
KEY_TYPE = ""
4242
DESCRIPTION = ""
4343

44-
def __init__(self, payload: bytes, key_type: str = None, description: str = None):
44+
def __init__(
45+
self,
46+
payload: bytes,
47+
key_type: Optional[str] = None,
48+
description: Optional[str] = None,
49+
):
4550
self._payload = payload
4651
self._key_type = key_type or self.KEY_TYPE
4752
self._description = description or self.KEY_TYPE
@@ -83,7 +88,7 @@ def to_json(self) -> str:
8388
)
8489

8590
@classmethod
86-
def from_json(cls, data: str, validate_type=False) -> Key:
91+
def from_json(cls: Type[Key], data: str, validate_type=False) -> Key:
8792
"""Restore a key from a JSON string.
8893
8994
Args:
@@ -105,8 +110,12 @@ def from_json(cls, data: str, validate_type=False) -> Key:
105110
f"Expect key type: {cls.KEY_TYPE}, got {obj['type']} instead."
106111
)
107112

113+
k = cls.from_cbor(obj["cborHex"])
114+
115+
assert isinstance(k, cls)
116+
108117
return cls(
109-
cls.from_cbor(obj["cborHex"]).payload,
118+
k.payload,
110119
key_type=obj["type"],
111120
description=obj["description"],
112121
)
@@ -244,19 +253,19 @@ class PaymentExtendedVerificationKey(ExtendedVerificationKey):
244253

245254

246255
class PaymentKeyPair:
247-
def __init__(
248-
self, signing_key: PaymentSigningKey, verification_key: PaymentVerificationKey
249-
):
256+
def __init__(self, signing_key: SigningKey, verification_key: VerificationKey):
250257
self.signing_key = signing_key
251258
self.verification_key = verification_key
252259

253260
@classmethod
254-
def generate(cls) -> PaymentKeyPair:
261+
def generate(cls: Type[PaymentKeyPair]) -> PaymentKeyPair:
255262
signing_key = PaymentSigningKey.generate()
256263
return cls.from_signing_key(signing_key)
257264

258265
@classmethod
259-
def from_signing_key(cls, signing_key: PaymentSigningKey) -> PaymentKeyPair:
266+
def from_signing_key(
267+
cls: Type[PaymentKeyPair], signing_key: SigningKey
268+
) -> PaymentKeyPair:
260269
return cls(signing_key, PaymentVerificationKey.from_signing_key(signing_key))
261270

262271
def __eq__(self, other):
@@ -288,17 +297,17 @@ class StakeExtendedVerificationKey(ExtendedVerificationKey):
288297

289298

290299
class StakeKeyPair:
291-
def __init__(
292-
self, signing_key: StakeSigningKey, verification_key: StakeVerificationKey
293-
):
300+
def __init__(self, signing_key: SigningKey, verification_key: VerificationKey):
294301
self.signing_key = signing_key
295302
self.verification_key = verification_key
296303

297304
@classmethod
298-
def generate(cls) -> StakeKeyPair:
305+
def generate(cls: Type[StakeKeyPair]) -> StakeKeyPair:
299306
signing_key = StakeSigningKey.generate()
300307
return cls.from_signing_key(signing_key)
301308

302309
@classmethod
303-
def from_signing_key(cls, signing_key: StakeSigningKey) -> StakeKeyPair:
310+
def from_signing_key(
311+
cls: Type[StakeKeyPair], signing_key: SigningKey
312+
) -> StakeKeyPair:
304313
return cls(signing_key, StakeVerificationKey.from_signing_key(signing_key))

pycardano/metadata.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass, field
4-
from typing import Any, ClassVar, List, Type, Union
4+
from typing import Any, ClassVar, List, Optional, Type, Union
55

66
from cbor2 import CBORTag
77
from nacl.encoding import RawEncoder
@@ -20,7 +20,7 @@
2020
list_hook,
2121
)
2222

23-
__all__ = ["Metadata", "ShellayMarryMetadata", "AlonzoMetadata", "AuxiliaryData"]
23+
__all__ = ["Metadata", "ShelleyMarryMetadata", "AlonzoMetadata", "AuxiliaryData"]
2424

2525

2626
class Metadata(DictCBORSerializable):
@@ -68,9 +68,9 @@ def __init__(self, *args, **kwargs):
6868

6969

7070
@dataclass
71-
class ShellayMarryMetadata(ArrayCBORSerializable):
71+
class ShelleyMarryMetadata(ArrayCBORSerializable):
7272
metadata: Metadata
73-
native_scripts: List[NativeScript] = field(
73+
native_scripts: Optional[List[NativeScript]] = field(
7474
default=None, metadata={"object_hook": list_hook(NativeScript)}
7575
)
7676

@@ -79,12 +79,14 @@ class ShellayMarryMetadata(ArrayCBORSerializable):
7979
class AlonzoMetadata(MapCBORSerializable):
8080
TAG: ClassVar[int] = 259
8181

82-
metadata: Metadata = field(default=None, metadata={"optional": True, "key": 0})
83-
native_scripts: List[NativeScript] = field(
82+
metadata: Optional[Metadata] = field(
83+
default=None, metadata={"optional": True, "key": 0}
84+
)
85+
native_scripts: Optional[List[NativeScript]] = field(
8486
default=None,
8587
metadata={"optional": True, "key": 1, "object_hook": list_hook(NativeScript)},
8688
)
87-
plutus_scripts: List[bytes] = field(
89+
plutus_scripts: Optional[List[bytes]] = field(
8890
default=None, metadata={"optional": True, "key": 2}
8991
)
9092

@@ -107,23 +109,23 @@ def from_primitive(cls: Type[AlonzoMetadata], value: CBORTag) -> AlonzoMetadata:
107109

108110
@dataclass
109111
class AuxiliaryData(CBORSerializable):
110-
data: Union[Metadata, ShellayMarryMetadata, AlonzoMetadata]
112+
data: Union[Metadata, ShelleyMarryMetadata, AlonzoMetadata]
111113

112114
def to_primitive(self) -> Primitive:
113115
return self.data.to_primitive()
114116

115117
@classmethod
116118
def from_primitive(cls: Type[AuxiliaryData], value: Primitive) -> AuxiliaryData:
117-
for t in [AlonzoMetadata, ShellayMarryMetadata, Metadata]:
119+
for t in [AlonzoMetadata, ShelleyMarryMetadata, Metadata]:
118120
# The schema of metadata in different eras are mutually exclusive, so we can try deserializing
119121
# them one by one without worrying about mismatch.
120122
try:
121-
return AuxiliaryData(t.from_primitive(value))
123+
return AuxiliaryData(t.from_primitive(value)) # type: ignore
122124
except DeserializeException:
123125
pass
124126
raise DeserializeException(f"Couldn't parse auxiliary data: {value}")
125127

126128
def hash(self) -> AuxiliaryDataHash:
127129
return AuxiliaryDataHash(
128-
blake2b(self.to_cbor("bytes"), AUXILIARY_DATA_HASH_SIZE, encoder=RawEncoder)
130+
blake2b(self.to_cbor("bytes"), AUXILIARY_DATA_HASH_SIZE, encoder=RawEncoder) # type: ignore
129131
)

pycardano/plutus.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
CBORSerializable,
2222
DictCBORSerializable,
2323
IndefiniteList,
24+
Primitive,
2425
RawCBOR,
2526
default_encoder,
2627
limit_primitive_type,
@@ -39,6 +40,7 @@
3940
"PlutusV2Script",
4041
"RawPlutusData",
4142
"Redeemer",
43+
"ScriptType",
4244
"datum_hash",
4345
"plutus_script_hash",
4446
"script_hash",
@@ -471,7 +473,7 @@ def __post_init__(self):
471473
)
472474

473475
def to_shallow_primitive(self) -> CBORTag:
474-
primitives = super().to_shallow_primitive()
476+
primitives: Primitive = super().to_shallow_primitive()
475477
if primitives:
476478
primitives = IndefiniteList(primitives)
477479
tag = get_tag(self.CONSTR_ID)
@@ -544,7 +546,7 @@ def _dfs(obj):
544546
return json.dumps(_dfs(self), **kwargs)
545547

546548
@classmethod
547-
def from_dict(cls: PlutusData, data: dict) -> PlutusData:
549+
def from_dict(cls: Type[PlutusData], data: dict) -> PlutusData:
548550
"""Convert a dictionary to PlutusData
549551
550552
Args:
@@ -606,7 +608,7 @@ def _dfs(obj):
606608
return _dfs(data)
607609

608610
@classmethod
609-
def from_json(cls: PlutusData, data: str) -> PlutusData:
611+
def from_json(cls: Type[PlutusData], data: str) -> PlutusData:
610612
"""Restore a json encoded string to a PlutusData.
611613
612614
Args:
@@ -701,7 +703,7 @@ class Redeemer(ArrayCBORSerializable):
701703

702704
data: Any
703705

704-
ex_units: ExecutionUnits = None
706+
ex_units: Optional[ExecutionUnits] = None
705707

706708
@classmethod
707709
@limit_primitive_type(list)
@@ -729,13 +731,23 @@ def plutus_script_hash(
729731
return script_hash(script)
730732

731733

732-
def script_hash(
733-
script: Union[bytes, NativeScript, PlutusV1Script, PlutusV2Script]
734-
) -> ScriptHash:
734+
class PlutusV1Script(bytes):
735+
pass
736+
737+
738+
class PlutusV2Script(bytes):
739+
pass
740+
741+
742+
ScriptType = Union[bytes, NativeScript, PlutusV1Script, PlutusV2Script]
743+
"""Script type. A Union type that contains all valid script types."""
744+
745+
746+
def script_hash(script: ScriptType) -> ScriptHash:
735747
"""Calculates the hash of a script, which could be either native script or plutus script.
736748
737749
Args:
738-
script (Union[bytes, NativeScript, PlutusV1Script, PlutusV2Script]): A script.
750+
script (ScriptType): A script.
739751
740752
Returns:
741753
ScriptHash: blake2b hash of the script.
@@ -752,11 +764,3 @@ def script_hash(
752764
)
753765
else:
754766
raise TypeError(f"Unexpected script type: {type(script)}")
755-
756-
757-
class PlutusV1Script(bytes):
758-
pass
759-
760-
761-
class PlutusV2Script(bytes):
762-
pass

0 commit comments

Comments
 (0)