Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 64 additions & 41 deletions src/lean_spec/subspecs/xmss/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@

from typing import List, Tuple

from pydantic import model_validator

from lean_spec.subspecs.xmss.target_sum import (
PROD_TARGET_SUM_ENCODER,
TEST_TARGET_SUM_ENCODER,
TargetSumEncoder,
)
from lean_spec.types.uint import Uint64
from lean_spec.types import StrictBaseModel, Uint64

from .constants import (
PROD_CONFIG,
Expand All @@ -29,39 +31,60 @@
MerkleTree,
)
from .prf import PROD_PRF, TEST_PRF, Prf
from .rand import PROD_RAND, TEST_RAND, Rand
from .tweak_hash import (
PROD_TWEAK_HASHER,
TEST_TWEAK_HASHER,
TweakHasher,
)
from .utils import (
PROD_RAND,
TEST_RAND,
Rand,
bottom_tree_from_prf_key,
expand_activation_time,
)


class GeneralizedXmssScheme:
"""Instance of the Generalized XMSS signature scheme for a given config."""

def __init__(
self,
config: XmssConfig,
prf: Prf,
hasher: TweakHasher,
merkle_tree: MerkleTree,
encoder: TargetSumEncoder,
rand: Rand,
):
"""Initializes the scheme with a specific parameter set."""
self.config = config
self.prf = prf
self.hasher = hasher
self.merkle_tree = merkle_tree
self.encoder = encoder
self.rand = rand
from .utils import bottom_tree_from_prf_key, expand_activation_time


class GeneralizedXmssScheme(StrictBaseModel):
"""
Instance of the Generalized XMSS signature scheme for a given config.

This class enforces strict type checking to ensure only approved component
implementations are used. Subclasses of the base component types (such as
SeededPrf or SeededRand) are explicitly rejected.
"""

config: XmssConfig
"""Configuration parameters for the XMSS scheme."""

prf: Prf
"""Pseudorandom function for deriving secret values."""

hasher: TweakHasher
"""Hash function with tweakable domain separation."""

merkle_tree: MerkleTree
"""Merkle tree implementation for authentication paths."""

encoder: TargetSumEncoder
"""Message encoder that produces valid codewords."""

rand: Rand
"""Random data generator for key generation."""

@model_validator(mode="after")
def enforce_strict_types(self) -> "GeneralizedXmssScheme":
"""Validates that only exact approved types are used (rejects subclasses)."""
checks = {
"config": XmssConfig,
"prf": Prf,
"hasher": TweakHasher,
"merkle_tree": MerkleTree,
"encoder": TargetSumEncoder,
"rand": Rand,
}
for field, expected in checks.items():
if type(getattr(self, field)) is not expected:
raise TypeError(
f"{field} must be exactly {expected.__name__}, "
f"got {type(getattr(self, field)).__name__}"
)
return self

def key_gen(
self, activation_epoch: Uint64, num_active_epochs: Uint64
Expand Down Expand Up @@ -571,21 +594,21 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey:


PROD_SIGNATURE_SCHEME = GeneralizedXmssScheme(
PROD_CONFIG,
PROD_PRF,
PROD_TWEAK_HASHER,
PROD_MERKLE_TREE,
PROD_TARGET_SUM_ENCODER,
PROD_RAND,
config=PROD_CONFIG,
prf=PROD_PRF,
hasher=PROD_TWEAK_HASHER,
merkle_tree=PROD_MERKLE_TREE,
encoder=PROD_TARGET_SUM_ENCODER,
rand=PROD_RAND,
)
"""An instance configured for production-level parameters."""

TEST_SIGNATURE_SCHEME = GeneralizedXmssScheme(
TEST_CONFIG,
TEST_PRF,
TEST_TWEAK_HASHER,
TEST_MERKLE_TREE,
TEST_TARGET_SUM_ENCODER,
TEST_RAND,
config=TEST_CONFIG,
prf=TEST_PRF,
hasher=TEST_TWEAK_HASHER,
merkle_tree=TEST_MERKLE_TREE,
encoder=TEST_TARGET_SUM_ENCODER,
rand=TEST_RAND,
)
"""A lightweight instance for test environments."""
37 changes: 27 additions & 10 deletions src/lean_spec/subspecs/xmss/merkle_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@

from typing import List

from lean_spec.types import Uint64
from pydantic import model_validator

from lean_spec.types import StrictBaseModel, Uint64

from .constants import (
PROD_CONFIG,
Expand All @@ -46,23 +48,38 @@
HashTreeOpening,
Parameter,
)
from .rand import PROD_RAND, TEST_RAND, Rand
from .tweak_hash import (
PROD_TWEAK_HASHER,
TEST_TWEAK_HASHER,
TreeTweak,
TweakHasher,
)
from .utils import PROD_RAND, TEST_RAND, Rand


class MerkleTree:
class MerkleTree(StrictBaseModel):
"""An instance of the Merkle Tree handler for a given config."""

def __init__(self, config: XmssConfig, hasher: TweakHasher, rand: Rand):
"""Initializes with a config, a hasher, and a random generator."""
self.config = config
self.hasher = hasher
self.rand = rand
config: XmssConfig
"""Configuration parameters for the Merkle tree."""

hasher: TweakHasher
"""Hash function for hashing tree nodes."""

rand: Rand
"""Random generator for padding."""

@model_validator(mode="after")
def enforce_strict_types(self) -> "MerkleTree":
"""Validates that only exact approved types are used (rejects subclasses)."""
checks = {"config": XmssConfig, "hasher": TweakHasher, "rand": Rand}
for field, expected in checks.items():
if type(getattr(self, field)) is not expected:
raise TypeError(
f"{field} must be exactly {expected.__name__}, "
f"got {type(getattr(self, field)).__name__}"
)
return self

def _get_padded_layer(self, nodes: List[HashDigest], start_index: int) -> HashTreeLayer:
"""
Expand Down Expand Up @@ -320,8 +337,8 @@ def verify_path(
return current_node == root


PROD_MERKLE_TREE = MerkleTree(PROD_CONFIG, PROD_TWEAK_HASHER, PROD_RAND)
PROD_MERKLE_TREE = MerkleTree(config=PROD_CONFIG, hasher=PROD_TWEAK_HASHER, rand=PROD_RAND)
"""An instance configured for production-level parameters."""

TEST_MERKLE_TREE = MerkleTree(TEST_CONFIG, TEST_TWEAK_HASHER, TEST_RAND)
TEST_MERKLE_TREE = MerkleTree(config=TEST_CONFIG, hasher=TEST_TWEAK_HASHER, rand=TEST_RAND)
"""A lightweight instance for test environments."""
31 changes: 23 additions & 8 deletions src/lean_spec/subspecs/xmss/message_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@

from typing import List

from pydantic import model_validator

from lean_spec.subspecs.xmss.poseidon import (
PROD_POSEIDON,
TEST_POSEIDON,
PoseidonXmss,
)
from lean_spec.types import Uint64
from lean_spec.types import StrictBaseModel, Uint64

from ..koalabear import Fp, P
from .constants import (
Expand All @@ -54,13 +56,26 @@
from .utils import int_to_base_p


class MessageHasher:
class MessageHasher(StrictBaseModel):
"""An instance of the "Top Level" message hasher for a given config."""

def __init__(self, config: XmssConfig, poseidon_hasher: PoseidonXmss):
"""Initializes the hasher with a specific parameter set."""
self.config = config
self.poseidon = poseidon_hasher
config: XmssConfig
"""Configuration parameters for the hasher."""

poseidon: PoseidonXmss
"""Poseidon hash engine."""

@model_validator(mode="after")
def enforce_strict_types(self) -> "MessageHasher":
"""Validates that only exact approved types are used (rejects subclasses)."""
checks = {"config": XmssConfig, "poseidon": PoseidonXmss}
for field, expected in checks.items():
if type(getattr(self, field)) is not expected:
raise TypeError(
f"{field} must be exactly {expected.__name__}, "
f"got {type(getattr(self, field)).__name__}"
)
return self

def encode_message(self, message: bytes) -> List[Fp]:
"""
Expand Down Expand Up @@ -187,8 +202,8 @@ def apply(
return self._map_into_hypercube_part(poseidon_outputs)


PROD_MESSAGE_HASHER = MessageHasher(PROD_CONFIG, PROD_POSEIDON)
PROD_MESSAGE_HASHER = MessageHasher(config=PROD_CONFIG, poseidon=PROD_POSEIDON)
"""An instance configured for production-level parameters."""

TEST_MESSAGE_HASHER = MessageHasher(TEST_CONFIG, TEST_POSEIDON)
TEST_MESSAGE_HASHER = MessageHasher(config=TEST_CONFIG, poseidon=TEST_POSEIDON)
"""A lightweight instance for test environments."""
35 changes: 26 additions & 9 deletions src/lean_spec/subspecs/xmss/poseidon.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@

from typing import List

from pydantic import model_validator

from lean_spec.types import StrictBaseModel

from ..koalabear import Fp
from ..poseidon2.permutation import (
PARAMS_16,
Expand All @@ -35,13 +39,26 @@
from .utils import int_to_base_p


class PoseidonXmss:
class PoseidonXmss(StrictBaseModel):
"""An instance of the Poseidon2 hash engine for the XMSS scheme."""

def __init__(self, params16: Poseidon2Params, params24: Poseidon2Params):
"""Initializes the hasher with specific Poseidon2 permutations."""
self.params16 = params16
self.params24 = params24
params16: Poseidon2Params
"""Poseidon2 parameters for 16-width permutation."""

params24: Poseidon2Params
"""Poseidon2 parameters for 24-width permutation."""

@model_validator(mode="after")
def enforce_strict_types(self) -> "PoseidonXmss":
"""Validates that only exact approved types are used (rejects subclasses)."""
checks = {"params16": Poseidon2Params, "params24": Poseidon2Params}
for field, expected in checks.items():
if type(getattr(self, field)) is not expected:
raise TypeError(
f"{field} must be exactly {expected.__name__}, "
f"got {type(getattr(self, field)).__name__}"
)
return self

def compress(self, input_vec: List[Fp], width: int, output_len: int) -> HashDigest:
"""
Expand Down Expand Up @@ -198,8 +215,8 @@ def sponge(
return output[:output_len]


# An instance configured for production-level parameters.
PROD_POSEIDON: PoseidonXmss = PoseidonXmss(PARAMS_16, PARAMS_24)
PROD_POSEIDON = PoseidonXmss(params16=PARAMS_16, params24=PARAMS_24)
"""An instance configured for production-level parameters."""

# A lightweight instance for test environments.
TEST_POSEIDON: PoseidonXmss = PoseidonXmss(PARAMS_16, PARAMS_24)
TEST_POSEIDON = PoseidonXmss(params16=PARAMS_16, params24=PARAMS_24)
"""A lightweight instance for test environments."""
22 changes: 15 additions & 7 deletions src/lean_spec/subspecs/xmss/prf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import os
from typing import List

from pydantic import model_validator

from lean_spec.subspecs.koalabear import Fp
from lean_spec.types.uint import Uint64
from lean_spec.types import StrictBaseModel, Uint64

from .constants import (
PRF_KEY_LENGTH,
Expand Down Expand Up @@ -79,12 +81,18 @@
"""


class Prf:
class Prf(StrictBaseModel):
"""An instance of the SHAKE128-based PRF for a given config."""

def __init__(self, config: XmssConfig):
"""Initializes the PRF with a specific parameter set."""
self.config = config
config: XmssConfig
"""Configuration parameters for the PRF."""

@model_validator(mode="after")
def enforce_strict_types(self) -> "Prf":
"""Validates that only exact approved types are used (rejects subclasses)."""
if type(self.config) is not XmssConfig:
raise TypeError(f"config must be exactly XmssConfig, got {type(self.config).__name__}")
return self

def key_gen(self) -> PRFKey:
"""
Expand Down Expand Up @@ -228,8 +236,8 @@ def get_randomness(
]


PROD_PRF = Prf(PROD_CONFIG)
PROD_PRF = Prf(config=PROD_CONFIG)
"""An instance configured for production-level parameters."""

TEST_PRF = Prf(TEST_CONFIG)
TEST_PRF = Prf(config=TEST_CONFIG)
"""A lightweight instance for test environments."""
Loading
Loading