Skip to content

Improve Ogmios backend module #111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ jobs:
- name: Install dependencies
run: |
poetry install
- name: Run static analyses
run: |
make qa
- name: Run unit tests
run: |
poetry run pytest --doctest-modules --ignore=examples --cov=pycardano --cov-config=.coveragerc --cov-report=xml
- name: "Upload coverage to Codecov"
uses: codecov/codecov-action@v3
with:
fail_ci_if_error: true
- name: Run static analyses
run: |
make qa

continuous-integration:
runs-on: ${{ matrix.os }}
Expand Down
6 changes: 3 additions & 3 deletions pycardano/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from __future__ import annotations

from enum import Enum
from typing import Union
from typing import Union, Type

from pycardano.crypto.bech32 import decode, encode
from pycardano.exception import (
Expand Down Expand Up @@ -160,7 +160,7 @@ def to_primitive(self) -> bytes:
return self.encode()

@classmethod
def from_primitive(cls, value: bytes) -> PointerAddress:
def from_primitive(cls: Type[PointerAddress], value: bytes) -> PointerAddress:
return cls.decode(value)

def __eq__(self, other):
Expand Down Expand Up @@ -339,7 +339,7 @@ def to_primitive(self) -> bytes:
return bytes(self)

@classmethod
def from_primitive(cls, value: Union[bytes, str]) -> Address:
def from_primitive(cls: Type[Address], value: Union[bytes, str]) -> Address:
if isinstance(value, str):
value = bytes(decode(value))
header = value[0]
Expand Down
59 changes: 36 additions & 23 deletions pycardano/backend/ogmios.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import calendar
import json
import time
from typing import Dict, List, Union
from enum import Enum
from typing import Any, Dict, List, Optional, Union, Tuple

import cbor2
import requests
Expand Down Expand Up @@ -32,7 +33,24 @@
__all__ = ["OgmiosChainContext"]


JSON = Dict[str, Any]


class OgmiosQueryType(str, Enum):
Query = "Query"
SubmitTx = "SubmitTx"
EvaluateTx = "EvaluateTx"


class OgmiosChainContext(ChainContext):
_ws_url: str
_network: Network
_service_name: str
_kupo_url: Optional[str]
_last_known_block_slot: int
_genesis_param: Optional[GenesisParameters]
_protocol_param: Optional[ProtocolParameters]

def __init__(
self,
ws_url: str,
Expand All @@ -48,15 +66,15 @@ def __init__(
self._genesis_param = None
self._protocol_param = None

def _request(self, method: str, args: dict) -> Union[dict, int]:
def _request(self, method: OgmiosQueryType, args: JSON) -> Any:
ws = websocket.WebSocket()
ws.connect(self._ws_url)
request = json.dumps(
{
"type": "jsonwsp/request",
"version": "1.0",
"servicename": self._service_name,
"methodname": method,
"methodname": method.value,
"args": args,
},
separators=(",", ":"),
Expand Down Expand Up @@ -86,10 +104,9 @@ def _fraction_parser(fraction: str) -> float:
@property
def protocol_param(self) -> ProtocolParameters:
"""Get current protocol parameters"""
method = "Query"
args = {"query": "currentProtocolParameters"}
if not self._protocol_param or self._check_chain_tip_and_update():
result = self._request(method, args)
result = self._request(OgmiosQueryType.Query, args)
param = ProtocolParameters(
min_fee_constant=result["minFeeConstant"],
min_fee_coefficient=result["minFeeCoefficient"],
Expand Down Expand Up @@ -130,7 +147,7 @@ def protocol_param(self) -> ProtocolParameters:
param.cost_models["PlutusV2"] = param.cost_models.pop("plutus:v2")

args = {"query": "genesisConfig"}
result = self._request(method, args)
result = self._request(OgmiosQueryType.Query, args)
param.min_utxo = result["protocolParameters"]["minUtxoValue"]

self._protocol_param = param
Expand All @@ -139,10 +156,9 @@ def protocol_param(self) -> ProtocolParameters:
@property
def genesis_param(self) -> GenesisParameters:
"""Get chain genesis parameters"""
method = "Query"
args = {"query": "genesisConfig"}
if not self._genesis_param or self._check_chain_tip_and_update():
result = self._request(method, args)
result = self._request(OgmiosQueryType.Query, args)
system_start_unix = int(
calendar.timegm(
time.strptime(
Expand Down Expand Up @@ -174,23 +190,21 @@ def network(self) -> Network:
@property
def epoch(self) -> int:
"""Current epoch number"""
method = "Query"
args = {"query": "currentEpoch"}
return self._request(method, args)
return self._request(OgmiosQueryType.Query, args)

@property
def last_block_slot(self) -> int:
"""Slot number of last block"""
method = "Query"
args = {"query": "chainTip"}
return self._request(method, args)["slot"]
return self._request(OgmiosQueryType.Query, args)["slot"]

def _extract_asset_info(self, asset_hash: str):
def _extract_asset_info(self, asset_hash: str) -> Tuple[str, ScriptHash, AssetName]:
policy_hex, asset_name_hex = asset_hash.split(".")
policy = ScriptHash.from_primitive(policy_hex)
asset_name_hex = AssetName.from_primitive(asset_name_hex)
asset_name = AssetName.from_primitive(asset_name_hex)

return policy_hex, policy, asset_name_hex
return policy_hex, policy, asset_name

def _check_utxo_unspent(self, tx_id: str, index: int) -> bool:
"""Check whether an UTxO is unspent with Ogmios.
Expand All @@ -200,9 +214,8 @@ def _check_utxo_unspent(self, tx_id: str, index: int) -> bool:
index (int): transaction index.
"""

method = "Query"
args = {"query": {"utxo": [{"txId": tx_id, "index": index}]}}
results = self._request(method, args)
results = self._request(OgmiosQueryType.Query, args)

if results:
return True
Expand All @@ -220,6 +233,9 @@ def _utxos_kupo(self, address: str) -> List[UTxO]:
Returns:
List[UTxO]: A list of UTxOs.
"""
if self._kupo_url is None:
raise AssertionError("kupo_url object attribute has not been assigned properly.")

address_url = self._kupo_url + "/" + address
results = requests.get(address_url).json()

Expand Down Expand Up @@ -282,9 +298,8 @@ def _utxos_ogmios(self, address: str) -> List[UTxO]:
List[UTxO]: A list of UTxOs.
"""

method = "Query"
args = {"query": {"utxo": [address]}}
results = self._request(method, args)
results = self._request(OgmiosQueryType.Query, args)

utxos = []

Expand Down Expand Up @@ -374,9 +389,8 @@ def submit_tx(self, cbor: Union[bytes, str]):
if isinstance(cbor, bytes):
cbor = cbor.hex()

method = "SubmitTx"
args = {"submit": cbor}
result = self._request(method, args)
result = self._request(OgmiosQueryType.SubmitTx, args)
if "SubmitFail" in result:
raise TransactionFailedException(result["SubmitFail"])

Expand All @@ -395,9 +409,8 @@ def evaluate_tx(self, cbor: Union[bytes, str]) -> Dict[str, ExecutionUnits]:
if isinstance(cbor, bytes):
cbor = cbor.hex()

method = "EvaluateTx"
args = {"evaluate": cbor}
result = self._request(method, args)
result = self._request(OgmiosQueryType.EvaluateTx, args)
if "EvaluationResult" not in result:
raise TransactionFailedException(result)
else:
Expand Down
4 changes: 2 additions & 2 deletions pycardano/hash.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""All type of hashes in Cardano ledger spec."""

from typing import TypeVar, Union
from typing import TypeVar, Union, Type

from pycardano.serialization import CBORSerializable

Expand Down Expand Up @@ -67,7 +67,7 @@ def to_primitive(self) -> bytes:
return self.payload

@classmethod
def from_primitive(cls: T, value: Union[bytes, str]) -> T:
def from_primitive(cls: Type[T], value: Union[bytes, str]) -> T:
if isinstance(value, str):
value = bytes.fromhex(value)
return cls(value)
Expand Down
3 changes: 2 additions & 1 deletion pycardano/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
import os
from typing import Type

from nacl.encoding import RawEncoder
from nacl.hash import blake2b
Expand Down Expand Up @@ -61,7 +62,7 @@ def to_primitive(self) -> bytes:
return self.payload

@classmethod
def from_primitive(cls, value: bytes) -> Key:
def from_primitive(cls: Type["Key"], value: bytes) -> Key:
return cls(value)

def to_json(self) -> str:
Expand Down
6 changes: 3 additions & 3 deletions pycardano/metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, ClassVar, List, Union
from typing import Any, ClassVar, List, Union, Type

from cbor2 import CBORTag
from nacl.encoding import RawEncoder
Expand Down Expand Up @@ -91,7 +91,7 @@ def to_primitive(self) -> Primitive:
return CBORTag(AlonzoMetadata.TAG, super(AlonzoMetadata, self).to_primitive())

@classmethod
def from_primitive(cls: AlonzoMetadata, value: CBORTag) -> AlonzoMetadata:
def from_primitive(cls: Type[AlonzoMetadata], value: CBORTag) -> AlonzoMetadata:
if not hasattr(value, "tag"):
raise DeserializeException(
f"{value} does not match the data schema of AlonzoMetadata."
Expand All @@ -111,7 +111,7 @@ def to_primitive(self) -> Primitive:
return self.data.to_primitive()

@classmethod
def from_primitive(cls: AuxiliaryData, value: Primitive) -> AuxiliaryData:
def from_primitive(cls: Type[AuxiliaryData], value: Primitive) -> AuxiliaryData:
for t in [AlonzoMetadata, ShellayMarryMetadata, Metadata]:
# The schema of metadata in different eras are mutually exclusive, so we can try deserializing
# them one by one without worrying about mismatch.
Expand Down
4 changes: 2 additions & 2 deletions pycardano/nativescript.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import ClassVar, List, Union
from typing import ClassVar, List, Union, Type

from nacl.encoding import RawEncoder
from nacl.hash import blake2b
Expand All @@ -27,7 +27,7 @@
class NativeScript(ArrayCBORSerializable):
@classmethod
def from_primitive(
cls: NativeScript, value: Primitive
cls: Type[NativeScript], value: Primitive
) -> Union[
ScriptPubkey, ScriptAll, ScriptAny, ScriptNofK, InvalidBefore, InvalidHereAfter
]:
Expand Down
3 changes: 2 additions & 1 deletion pycardano/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from enum import Enum
from typing import Type

from pycardano.serialization import CBORSerializable

Expand All @@ -21,5 +22,5 @@ def to_primitive(self) -> int:
return self.value

@classmethod
def from_primitive(cls, value: int) -> Network:
def from_primitive(cls: Type[Network], value: int) -> Network:
return cls(value)
12 changes: 6 additions & 6 deletions pycardano/plutus.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import Any, ClassVar, List, Optional, Union
from typing import Any, ClassVar, List, Optional, Union, Type

import cbor2
from cbor2 import CBORTag
Expand Down Expand Up @@ -66,7 +66,7 @@ def to_shallow_primitive(self) -> dict:
return result

@classmethod
def from_primitive(cls: CostModels, value: dict) -> CostModels:
def from_primitive(cls: Type[CostModels], value: dict) -> CostModels:
raise DeserializeException(
"Deserialization of cost model is impossible, because some information is lost "
"during serialization."
Expand Down Expand Up @@ -480,7 +480,7 @@ def to_shallow_primitive(self) -> CBORTag:
return CBORTag(102, [self.CONSTR_ID, primitives])

@classmethod
def from_primitive(cls: PlutusData, value: CBORTag) -> PlutusData:
def from_primitive(cls: Type[PlutusData], value: CBORTag) -> PlutusData:
if not isinstance(value, CBORTag):
raise DeserializeException(
f"Unexpected type: {CBORTag}. Got {type(value)} instead."
Expand Down Expand Up @@ -643,7 +643,7 @@ def _dfs(obj):
return _dfs(self.data)

@classmethod
def from_primitive(cls: RawPlutusData, value: CBORTag) -> RawPlutusData:
def from_primitive(cls: Type[RawPlutusData], value: CBORTag) -> RawPlutusData:
return cls(value)


Expand Down Expand Up @@ -675,7 +675,7 @@ def to_primitive(self) -> int:
return self.value

@classmethod
def from_primitive(cls, value: int) -> RedeemerTag:
def from_primitive(cls: Type[RedeemerTag], value: int) -> RedeemerTag:
return cls(value)


Expand Down Expand Up @@ -704,7 +704,7 @@ class Redeemer(ArrayCBORSerializable):
ex_units: ExecutionUnits = None

@classmethod
def from_primitive(cls: Redeemer, values: List[Primitive]) -> Redeemer:
def from_primitive(cls: Type[Redeemer], values: List[Primitive]) -> Redeemer:
if isinstance(values[2], CBORTag) and cls is Redeemer:
values[2] = RawPlutusData.from_primitive(values[2])
redeemer = super(Redeemer, cls).from_primitive(
Expand Down
8 changes: 4 additions & 4 deletions pycardano/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def to_validated_primitive(self) -> Primitive:
return self.to_primitive()

@classmethod
def from_primitive(cls: CBORBase, value: Primitive) -> CBORBase:
def from_primitive(cls: Type[CBORBase], value: Primitive) -> CBORBase:
"""Turn a CBOR primitive to its original class type.

Args:
Expand Down Expand Up @@ -473,7 +473,7 @@ def to_shallow_primitive(self) -> List[Primitive]:
return primitives

@classmethod
def from_primitive(cls: ArrayBase, values: List[Primitive]) -> ArrayBase:
def from_primitive(cls: Type[ArrayBase], values: List[Primitive]) -> ArrayBase:
"""Restore a primitive value to its original class type.

Args:
Expand Down Expand Up @@ -585,7 +585,7 @@ def to_shallow_primitive(self) -> Primitive:
return primitives

@classmethod
def from_primitive(cls: MapBase, values: Primitive) -> MapBase:
def from_primitive(cls: Type[MapBase], values: Primitive) -> MapBase:
"""Restore a primitive value to its original class type.

Args:
Expand Down Expand Up @@ -704,7 +704,7 @@ def _get_sortable_val(key):
return dict(sorted(self.data.items(), key=lambda x: _get_sortable_val(x[0])))

@classmethod
def from_primitive(cls: DictBase, value: dict) -> DictBase:
def from_primitive(cls: Type[DictBase], value: dict) -> DictBase:
"""Restore a primitive value to its original class type.

Args:
Expand Down
Loading