Skip to content

Commit c5430c0

Browse files
authored
Merge pull request #111 from koreapool/enhance-backend-package-type-hint
Improve Ogmios backend module
2 parents f403c32 + 8092dc5 commit c5430c0

File tree

12 files changed

+68
-54
lines changed

12 files changed

+68
-54
lines changed

.github/workflows/main.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,16 @@ jobs:
3232
- name: Install dependencies
3333
run: |
3434
poetry install
35-
- name: Run static analyses
36-
run: |
37-
make qa
3835
- name: Run unit tests
3936
run: |
4037
poetry run pytest --doctest-modules --ignore=examples --cov=pycardano --cov-config=.coveragerc --cov-report=xml
4138
- name: "Upload coverage to Codecov"
4239
uses: codecov/codecov-action@v3
4340
with:
4441
fail_ci_if_error: true
42+
- name: Run static analyses
43+
run: |
44+
make qa
4545
4646
continuous-integration:
4747
runs-on: ${{ matrix.os }}

pycardano/address.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from __future__ import annotations
1010

1111
from enum import Enum
12-
from typing import Union
12+
from typing import Union, Type
1313

1414
from pycardano.crypto.bech32 import decode, encode
1515
from pycardano.exception import (
@@ -160,7 +160,7 @@ def to_primitive(self) -> bytes:
160160
return self.encode()
161161

162162
@classmethod
163-
def from_primitive(cls, value: bytes) -> PointerAddress:
163+
def from_primitive(cls: Type[PointerAddress], value: bytes) -> PointerAddress:
164164
return cls.decode(value)
165165

166166
def __eq__(self, other):
@@ -339,7 +339,7 @@ def to_primitive(self) -> bytes:
339339
return bytes(self)
340340

341341
@classmethod
342-
def from_primitive(cls, value: Union[bytes, str]) -> Address:
342+
def from_primitive(cls: Type[Address], value: Union[bytes, str]) -> Address:
343343
if isinstance(value, str):
344344
value = bytes(decode(value))
345345
header = value[0]

pycardano/backend/ogmios.py

+36-23
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import calendar
22
import json
33
import time
4-
from typing import Dict, List, Union
4+
from enum import Enum
5+
from typing import Any, Dict, List, Optional, Union, Tuple
56

67
import cbor2
78
import requests
@@ -32,7 +33,24 @@
3233
__all__ = ["OgmiosChainContext"]
3334

3435

36+
JSON = Dict[str, Any]
37+
38+
39+
class OgmiosQueryType(str, Enum):
40+
Query = "Query"
41+
SubmitTx = "SubmitTx"
42+
EvaluateTx = "EvaluateTx"
43+
44+
3545
class OgmiosChainContext(ChainContext):
46+
_ws_url: str
47+
_network: Network
48+
_service_name: str
49+
_kupo_url: Optional[str]
50+
_last_known_block_slot: int
51+
_genesis_param: Optional[GenesisParameters]
52+
_protocol_param: Optional[ProtocolParameters]
53+
3654
def __init__(
3755
self,
3856
ws_url: str,
@@ -48,15 +66,15 @@ def __init__(
4866
self._genesis_param = None
4967
self._protocol_param = None
5068

51-
def _request(self, method: str, args: dict) -> Union[dict, int]:
69+
def _request(self, method: OgmiosQueryType, args: JSON) -> Any:
5270
ws = websocket.WebSocket()
5371
ws.connect(self._ws_url)
5472
request = json.dumps(
5573
{
5674
"type": "jsonwsp/request",
5775
"version": "1.0",
5876
"servicename": self._service_name,
59-
"methodname": method,
77+
"methodname": method.value,
6078
"args": args,
6179
},
6280
separators=(",", ":"),
@@ -86,10 +104,9 @@ def _fraction_parser(fraction: str) -> float:
86104
@property
87105
def protocol_param(self) -> ProtocolParameters:
88106
"""Get current protocol parameters"""
89-
method = "Query"
90107
args = {"query": "currentProtocolParameters"}
91108
if not self._protocol_param or self._check_chain_tip_and_update():
92-
result = self._request(method, args)
109+
result = self._request(OgmiosQueryType.Query, args)
93110
param = ProtocolParameters(
94111
min_fee_constant=result["minFeeConstant"],
95112
min_fee_coefficient=result["minFeeCoefficient"],
@@ -130,7 +147,7 @@ def protocol_param(self) -> ProtocolParameters:
130147
param.cost_models["PlutusV2"] = param.cost_models.pop("plutus:v2")
131148

132149
args = {"query": "genesisConfig"}
133-
result = self._request(method, args)
150+
result = self._request(OgmiosQueryType.Query, args)
134151
param.min_utxo = result["protocolParameters"]["minUtxoValue"]
135152

136153
self._protocol_param = param
@@ -139,10 +156,9 @@ def protocol_param(self) -> ProtocolParameters:
139156
@property
140157
def genesis_param(self) -> GenesisParameters:
141158
"""Get chain genesis parameters"""
142-
method = "Query"
143159
args = {"query": "genesisConfig"}
144160
if not self._genesis_param or self._check_chain_tip_and_update():
145-
result = self._request(method, args)
161+
result = self._request(OgmiosQueryType.Query, args)
146162
system_start_unix = int(
147163
calendar.timegm(
148164
time.strptime(
@@ -174,23 +190,21 @@ def network(self) -> Network:
174190
@property
175191
def epoch(self) -> int:
176192
"""Current epoch number"""
177-
method = "Query"
178193
args = {"query": "currentEpoch"}
179-
return self._request(method, args)
194+
return self._request(OgmiosQueryType.Query, args)
180195

181196
@property
182197
def last_block_slot(self) -> int:
183198
"""Slot number of last block"""
184-
method = "Query"
185199
args = {"query": "chainTip"}
186-
return self._request(method, args)["slot"]
200+
return self._request(OgmiosQueryType.Query, args)["slot"]
187201

188-
def _extract_asset_info(self, asset_hash: str):
202+
def _extract_asset_info(self, asset_hash: str) -> Tuple[str, ScriptHash, AssetName]:
189203
policy_hex, asset_name_hex = asset_hash.split(".")
190204
policy = ScriptHash.from_primitive(policy_hex)
191-
asset_name_hex = AssetName.from_primitive(asset_name_hex)
205+
asset_name = AssetName.from_primitive(asset_name_hex)
192206

193-
return policy_hex, policy, asset_name_hex
207+
return policy_hex, policy, asset_name
194208

195209
def _check_utxo_unspent(self, tx_id: str, index: int) -> bool:
196210
"""Check whether an UTxO is unspent with Ogmios.
@@ -200,9 +214,8 @@ def _check_utxo_unspent(self, tx_id: str, index: int) -> bool:
200214
index (int): transaction index.
201215
"""
202216

203-
method = "Query"
204217
args = {"query": {"utxo": [{"txId": tx_id, "index": index}]}}
205-
results = self._request(method, args)
218+
results = self._request(OgmiosQueryType.Query, args)
206219

207220
if results:
208221
return True
@@ -220,6 +233,9 @@ def _utxos_kupo(self, address: str) -> List[UTxO]:
220233
Returns:
221234
List[UTxO]: A list of UTxOs.
222235
"""
236+
if self._kupo_url is None:
237+
raise AssertionError("kupo_url object attribute has not been assigned properly.")
238+
223239
address_url = self._kupo_url + "/" + address
224240
results = requests.get(address_url).json()
225241

@@ -282,9 +298,8 @@ def _utxos_ogmios(self, address: str) -> List[UTxO]:
282298
List[UTxO]: A list of UTxOs.
283299
"""
284300

285-
method = "Query"
286301
args = {"query": {"utxo": [address]}}
287-
results = self._request(method, args)
302+
results = self._request(OgmiosQueryType.Query, args)
288303

289304
utxos = []
290305

@@ -374,9 +389,8 @@ def submit_tx(self, cbor: Union[bytes, str]):
374389
if isinstance(cbor, bytes):
375390
cbor = cbor.hex()
376391

377-
method = "SubmitTx"
378392
args = {"submit": cbor}
379-
result = self._request(method, args)
393+
result = self._request(OgmiosQueryType.SubmitTx, args)
380394
if "SubmitFail" in result:
381395
raise TransactionFailedException(result["SubmitFail"])
382396

@@ -395,9 +409,8 @@ def evaluate_tx(self, cbor: Union[bytes, str]) -> Dict[str, ExecutionUnits]:
395409
if isinstance(cbor, bytes):
396410
cbor = cbor.hex()
397411

398-
method = "EvaluateTx"
399412
args = {"evaluate": cbor}
400-
result = self._request(method, args)
413+
result = self._request(OgmiosQueryType.EvaluateTx, args)
401414
if "EvaluationResult" not in result:
402415
raise TransactionFailedException(result)
403416
else:

pycardano/hash.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""All type of hashes in Cardano ledger spec."""
22

3-
from typing import TypeVar, Union
3+
from typing import TypeVar, Union, Type
44

55
from pycardano.serialization import CBORSerializable
66

@@ -67,7 +67,7 @@ def to_primitive(self) -> bytes:
6767
return self.payload
6868

6969
@classmethod
70-
def from_primitive(cls: T, value: Union[bytes, str]) -> T:
70+
def from_primitive(cls: Type[T], value: Union[bytes, str]) -> T:
7171
if isinstance(value, str):
7272
value = bytes.fromhex(value)
7373
return cls(value)

pycardano/key.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import json
66
import os
7+
from typing import Type
78

89
from nacl.encoding import RawEncoder
910
from nacl.hash import blake2b
@@ -61,7 +62,7 @@ def to_primitive(self) -> bytes:
6162
return self.payload
6263

6364
@classmethod
64-
def from_primitive(cls, value: bytes) -> Key:
65+
def from_primitive(cls: Type["Key"], value: bytes) -> Key:
6566
return cls(value)
6667

6768
def to_json(self) -> str:

pycardano/metadata.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass, field
4-
from typing import Any, ClassVar, List, Union
4+
from typing import Any, ClassVar, List, Union, Type
55

66
from cbor2 import CBORTag
77
from nacl.encoding import RawEncoder
@@ -91,7 +91,7 @@ def to_primitive(self) -> Primitive:
9191
return CBORTag(AlonzoMetadata.TAG, super(AlonzoMetadata, self).to_primitive())
9292

9393
@classmethod
94-
def from_primitive(cls: AlonzoMetadata, value: CBORTag) -> AlonzoMetadata:
94+
def from_primitive(cls: Type[AlonzoMetadata], value: CBORTag) -> AlonzoMetadata:
9595
if not hasattr(value, "tag"):
9696
raise DeserializeException(
9797
f"{value} does not match the data schema of AlonzoMetadata."
@@ -111,7 +111,7 @@ def to_primitive(self) -> Primitive:
111111
return self.data.to_primitive()
112112

113113
@classmethod
114-
def from_primitive(cls: AuxiliaryData, value: Primitive) -> AuxiliaryData:
114+
def from_primitive(cls: Type[AuxiliaryData], value: Primitive) -> AuxiliaryData:
115115
for t in [AlonzoMetadata, ShellayMarryMetadata, Metadata]:
116116
# The schema of metadata in different eras are mutually exclusive, so we can try deserializing
117117
# them one by one without worrying about mismatch.

pycardano/nativescript.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass, field
6-
from typing import ClassVar, List, Union
6+
from typing import ClassVar, List, Union, Type
77

88
from nacl.encoding import RawEncoder
99
from nacl.hash import blake2b
@@ -27,7 +27,7 @@
2727
class NativeScript(ArrayCBORSerializable):
2828
@classmethod
2929
def from_primitive(
30-
cls: NativeScript, value: Primitive
30+
cls: Type[NativeScript], value: Primitive
3131
) -> Union[
3232
ScriptPubkey, ScriptAll, ScriptAny, ScriptNofK, InvalidBefore, InvalidHereAfter
3333
]:

pycardano/network.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from enum import Enum
6+
from typing import Type
67

78
from pycardano.serialization import CBORSerializable
89

@@ -21,5 +22,5 @@ def to_primitive(self) -> int:
2122
return self.value
2223

2324
@classmethod
24-
def from_primitive(cls, value: int) -> Network:
25+
def from_primitive(cls: Type[Network], value: int) -> Network:
2526
return cls(value)

pycardano/plutus.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77
from dataclasses import dataclass, field, fields
88
from enum import Enum
9-
from typing import Any, ClassVar, List, Optional, Union
9+
from typing import Any, ClassVar, List, Optional, Union, Type
1010

1111
import cbor2
1212
from cbor2 import CBORTag
@@ -66,7 +66,7 @@ def to_shallow_primitive(self) -> dict:
6666
return result
6767

6868
@classmethod
69-
def from_primitive(cls: CostModels, value: dict) -> CostModels:
69+
def from_primitive(cls: Type[CostModels], value: dict) -> CostModels:
7070
raise DeserializeException(
7171
"Deserialization of cost model is impossible, because some information is lost "
7272
"during serialization."
@@ -480,7 +480,7 @@ def to_shallow_primitive(self) -> CBORTag:
480480
return CBORTag(102, [self.CONSTR_ID, primitives])
481481

482482
@classmethod
483-
def from_primitive(cls: PlutusData, value: CBORTag) -> PlutusData:
483+
def from_primitive(cls: Type[PlutusData], value: CBORTag) -> PlutusData:
484484
if not isinstance(value, CBORTag):
485485
raise DeserializeException(
486486
f"Unexpected type: {CBORTag}. Got {type(value)} instead."
@@ -643,7 +643,7 @@ def _dfs(obj):
643643
return _dfs(self.data)
644644

645645
@classmethod
646-
def from_primitive(cls: RawPlutusData, value: CBORTag) -> RawPlutusData:
646+
def from_primitive(cls: Type[RawPlutusData], value: CBORTag) -> RawPlutusData:
647647
return cls(value)
648648

649649

@@ -675,7 +675,7 @@ def to_primitive(self) -> int:
675675
return self.value
676676

677677
@classmethod
678-
def from_primitive(cls, value: int) -> RedeemerTag:
678+
def from_primitive(cls: Type[RedeemerTag], value: int) -> RedeemerTag:
679679
return cls(value)
680680

681681

@@ -704,7 +704,7 @@ class Redeemer(ArrayCBORSerializable):
704704
ex_units: ExecutionUnits = None
705705

706706
@classmethod
707-
def from_primitive(cls: Redeemer, values: List[Primitive]) -> Redeemer:
707+
def from_primitive(cls: Type[Redeemer], values: List[Primitive]) -> Redeemer:
708708
if isinstance(values[2], CBORTag) and cls is Redeemer:
709709
values[2] = RawPlutusData.from_primitive(values[2])
710710
redeemer = super(Redeemer, cls).from_primitive(

pycardano/serialization.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def to_validated_primitive(self) -> Primitive:
222222
return self.to_primitive()
223223

224224
@classmethod
225-
def from_primitive(cls: CBORBase, value: Primitive) -> CBORBase:
225+
def from_primitive(cls: Type[CBORBase], value: Primitive) -> CBORBase:
226226
"""Turn a CBOR primitive to its original class type.
227227
228228
Args:
@@ -473,7 +473,7 @@ def to_shallow_primitive(self) -> List[Primitive]:
473473
return primitives
474474

475475
@classmethod
476-
def from_primitive(cls: ArrayBase, values: List[Primitive]) -> ArrayBase:
476+
def from_primitive(cls: Type[ArrayBase], values: List[Primitive]) -> ArrayBase:
477477
"""Restore a primitive value to its original class type.
478478
479479
Args:
@@ -585,7 +585,7 @@ def to_shallow_primitive(self) -> Primitive:
585585
return primitives
586586

587587
@classmethod
588-
def from_primitive(cls: MapBase, values: Primitive) -> MapBase:
588+
def from_primitive(cls: Type[MapBase], values: Primitive) -> MapBase:
589589
"""Restore a primitive value to its original class type.
590590
591591
Args:
@@ -704,7 +704,7 @@ def _get_sortable_val(key):
704704
return dict(sorted(self.data.items(), key=lambda x: _get_sortable_val(x[0])))
705705

706706
@classmethod
707-
def from_primitive(cls: DictBase, value: dict) -> DictBase:
707+
def from_primitive(cls: Type[DictBase], value: dict) -> DictBase:
708708
"""Restore a primitive value to its original class type.
709709
710710
Args:

0 commit comments

Comments
 (0)