diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index bc500444..cf7986d4 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -10,12 +10,14 @@ from typing import List, Tuple +from pydantic import model_validator + from lean_spec.subspecs.xmss.target_sum import ( PROD_TARGET_SUM_ENCODER, TEST_TARGET_SUM_ENCODER, TargetSumEncoder, ) -from lean_spec.types.uint import Uint64 +from lean_spec.types import StrictBaseModel, Uint64 from .constants import ( PROD_CONFIG, @@ -29,39 +31,60 @@ MerkleTree, ) from .prf import PROD_PRF, TEST_PRF, Prf +from .rand import PROD_RAND, TEST_RAND, Rand from .tweak_hash import ( PROD_TWEAK_HASHER, TEST_TWEAK_HASHER, TweakHasher, ) -from .utils import ( - PROD_RAND, - TEST_RAND, - Rand, - bottom_tree_from_prf_key, - expand_activation_time, -) - - -class GeneralizedXmssScheme: - """Instance of the Generalized XMSS signature scheme for a given config.""" - - def __init__( - self, - config: XmssConfig, - prf: Prf, - hasher: TweakHasher, - merkle_tree: MerkleTree, - encoder: TargetSumEncoder, - rand: Rand, - ): - """Initializes the scheme with a specific parameter set.""" - self.config = config - self.prf = prf - self.hasher = hasher - self.merkle_tree = merkle_tree - self.encoder = encoder - self.rand = rand +from .utils import bottom_tree_from_prf_key, expand_activation_time + + +class GeneralizedXmssScheme(StrictBaseModel): + """ + Instance of the Generalized XMSS signature scheme for a given config. + + This class enforces strict type checking to ensure only approved component + implementations are used. Subclasses of the base component types (such as + SeededPrf or SeededRand) are explicitly rejected. + """ + + config: XmssConfig + """Configuration parameters for the XMSS scheme.""" + + prf: Prf + """Pseudorandom function for deriving secret values.""" + + hasher: TweakHasher + """Hash function with tweakable domain separation.""" + + merkle_tree: MerkleTree + """Merkle tree implementation for authentication paths.""" + + encoder: TargetSumEncoder + """Message encoder that produces valid codewords.""" + + rand: Rand + """Random data generator for key generation.""" + + @model_validator(mode="after") + def enforce_strict_types(self) -> "GeneralizedXmssScheme": + """Validates that only exact approved types are used (rejects subclasses).""" + checks = { + "config": XmssConfig, + "prf": Prf, + "hasher": TweakHasher, + "merkle_tree": MerkleTree, + "encoder": TargetSumEncoder, + "rand": Rand, + } + for field, expected in checks.items(): + if type(getattr(self, field)) is not expected: + raise TypeError( + f"{field} must be exactly {expected.__name__}, " + f"got {type(getattr(self, field)).__name__}" + ) + return self def key_gen( self, activation_epoch: Uint64, num_active_epochs: Uint64 @@ -571,21 +594,21 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: PROD_SIGNATURE_SCHEME = GeneralizedXmssScheme( - PROD_CONFIG, - PROD_PRF, - PROD_TWEAK_HASHER, - PROD_MERKLE_TREE, - PROD_TARGET_SUM_ENCODER, - PROD_RAND, + config=PROD_CONFIG, + prf=PROD_PRF, + hasher=PROD_TWEAK_HASHER, + merkle_tree=PROD_MERKLE_TREE, + encoder=PROD_TARGET_SUM_ENCODER, + rand=PROD_RAND, ) """An instance configured for production-level parameters.""" TEST_SIGNATURE_SCHEME = GeneralizedXmssScheme( - TEST_CONFIG, - TEST_PRF, - TEST_TWEAK_HASHER, - TEST_MERKLE_TREE, - TEST_TARGET_SUM_ENCODER, - TEST_RAND, + config=TEST_CONFIG, + prf=TEST_PRF, + hasher=TEST_TWEAK_HASHER, + merkle_tree=TEST_MERKLE_TREE, + encoder=TEST_TARGET_SUM_ENCODER, + rand=TEST_RAND, ) """A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/merkle_tree.py b/src/lean_spec/subspecs/xmss/merkle_tree.py index c7b7fc0c..336bf4b6 100644 --- a/src/lean_spec/subspecs/xmss/merkle_tree.py +++ b/src/lean_spec/subspecs/xmss/merkle_tree.py @@ -32,7 +32,9 @@ from typing import List -from lean_spec.types import Uint64 +from pydantic import model_validator + +from lean_spec.types import StrictBaseModel, Uint64 from .constants import ( PROD_CONFIG, @@ -46,23 +48,38 @@ HashTreeOpening, Parameter, ) +from .rand import PROD_RAND, TEST_RAND, Rand from .tweak_hash import ( PROD_TWEAK_HASHER, TEST_TWEAK_HASHER, TreeTweak, TweakHasher, ) -from .utils import PROD_RAND, TEST_RAND, Rand -class MerkleTree: +class MerkleTree(StrictBaseModel): """An instance of the Merkle Tree handler for a given config.""" - def __init__(self, config: XmssConfig, hasher: TweakHasher, rand: Rand): - """Initializes with a config, a hasher, and a random generator.""" - self.config = config - self.hasher = hasher - self.rand = rand + config: XmssConfig + """Configuration parameters for the Merkle tree.""" + + hasher: TweakHasher + """Hash function for hashing tree nodes.""" + + rand: Rand + """Random generator for padding.""" + + @model_validator(mode="after") + def enforce_strict_types(self) -> "MerkleTree": + """Validates that only exact approved types are used (rejects subclasses).""" + checks = {"config": XmssConfig, "hasher": TweakHasher, "rand": Rand} + for field, expected in checks.items(): + if type(getattr(self, field)) is not expected: + raise TypeError( + f"{field} must be exactly {expected.__name__}, " + f"got {type(getattr(self, field)).__name__}" + ) + return self def _get_padded_layer(self, nodes: List[HashDigest], start_index: int) -> HashTreeLayer: """ @@ -320,8 +337,8 @@ def verify_path( return current_node == root -PROD_MERKLE_TREE = MerkleTree(PROD_CONFIG, PROD_TWEAK_HASHER, PROD_RAND) +PROD_MERKLE_TREE = MerkleTree(config=PROD_CONFIG, hasher=PROD_TWEAK_HASHER, rand=PROD_RAND) """An instance configured for production-level parameters.""" -TEST_MERKLE_TREE = MerkleTree(TEST_CONFIG, TEST_TWEAK_HASHER, TEST_RAND) +TEST_MERKLE_TREE = MerkleTree(config=TEST_CONFIG, hasher=TEST_TWEAK_HASHER, rand=TEST_RAND) """A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py index 154edf99..4b16a069 100644 --- a/src/lean_spec/subspecs/xmss/message_hash.py +++ b/src/lean_spec/subspecs/xmss/message_hash.py @@ -31,12 +31,14 @@ from typing import List +from pydantic import model_validator + from lean_spec.subspecs.xmss.poseidon import ( PROD_POSEIDON, TEST_POSEIDON, PoseidonXmss, ) -from lean_spec.types import Uint64 +from lean_spec.types import StrictBaseModel, Uint64 from ..koalabear import Fp, P from .constants import ( @@ -54,13 +56,26 @@ from .utils import int_to_base_p -class MessageHasher: +class MessageHasher(StrictBaseModel): """An instance of the "Top Level" message hasher for a given config.""" - def __init__(self, config: XmssConfig, poseidon_hasher: PoseidonXmss): - """Initializes the hasher with a specific parameter set.""" - self.config = config - self.poseidon = poseidon_hasher + config: XmssConfig + """Configuration parameters for the hasher.""" + + poseidon: PoseidonXmss + """Poseidon hash engine.""" + + @model_validator(mode="after") + def enforce_strict_types(self) -> "MessageHasher": + """Validates that only exact approved types are used (rejects subclasses).""" + checks = {"config": XmssConfig, "poseidon": PoseidonXmss} + for field, expected in checks.items(): + if type(getattr(self, field)) is not expected: + raise TypeError( + f"{field} must be exactly {expected.__name__}, " + f"got {type(getattr(self, field)).__name__}" + ) + return self def encode_message(self, message: bytes) -> List[Fp]: """ @@ -187,8 +202,8 @@ def apply( return self._map_into_hypercube_part(poseidon_outputs) -PROD_MESSAGE_HASHER = MessageHasher(PROD_CONFIG, PROD_POSEIDON) +PROD_MESSAGE_HASHER = MessageHasher(config=PROD_CONFIG, poseidon=PROD_POSEIDON) """An instance configured for production-level parameters.""" -TEST_MESSAGE_HASHER = MessageHasher(TEST_CONFIG, TEST_POSEIDON) +TEST_MESSAGE_HASHER = MessageHasher(config=TEST_CONFIG, poseidon=TEST_POSEIDON) """A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index e91ee4ce..86983622 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -24,6 +24,10 @@ from typing import List +from pydantic import model_validator + +from lean_spec.types import StrictBaseModel + from ..koalabear import Fp from ..poseidon2.permutation import ( PARAMS_16, @@ -35,13 +39,26 @@ from .utils import int_to_base_p -class PoseidonXmss: +class PoseidonXmss(StrictBaseModel): """An instance of the Poseidon2 hash engine for the XMSS scheme.""" - def __init__(self, params16: Poseidon2Params, params24: Poseidon2Params): - """Initializes the hasher with specific Poseidon2 permutations.""" - self.params16 = params16 - self.params24 = params24 + params16: Poseidon2Params + """Poseidon2 parameters for 16-width permutation.""" + + params24: Poseidon2Params + """Poseidon2 parameters for 24-width permutation.""" + + @model_validator(mode="after") + def enforce_strict_types(self) -> "PoseidonXmss": + """Validates that only exact approved types are used (rejects subclasses).""" + checks = {"params16": Poseidon2Params, "params24": Poseidon2Params} + for field, expected in checks.items(): + if type(getattr(self, field)) is not expected: + raise TypeError( + f"{field} must be exactly {expected.__name__}, " + f"got {type(getattr(self, field)).__name__}" + ) + return self def compress(self, input_vec: List[Fp], width: int, output_len: int) -> HashDigest: """ @@ -198,8 +215,8 @@ def sponge( return output[:output_len] -# An instance configured for production-level parameters. -PROD_POSEIDON: PoseidonXmss = PoseidonXmss(PARAMS_16, PARAMS_24) +PROD_POSEIDON = PoseidonXmss(params16=PARAMS_16, params24=PARAMS_24) +"""An instance configured for production-level parameters.""" -# A lightweight instance for test environments. -TEST_POSEIDON: PoseidonXmss = PoseidonXmss(PARAMS_16, PARAMS_24) +TEST_POSEIDON = PoseidonXmss(params16=PARAMS_16, params24=PARAMS_24) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py index 714fd424..211f80c5 100644 --- a/src/lean_spec/subspecs/xmss/prf.py +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -13,8 +13,10 @@ import os from typing import List +from pydantic import model_validator + from lean_spec.subspecs.koalabear import Fp -from lean_spec.types.uint import Uint64 +from lean_spec.types import StrictBaseModel, Uint64 from .constants import ( PRF_KEY_LENGTH, @@ -79,12 +81,18 @@ """ -class Prf: +class Prf(StrictBaseModel): """An instance of the SHAKE128-based PRF for a given config.""" - def __init__(self, config: XmssConfig): - """Initializes the PRF with a specific parameter set.""" - self.config = config + config: XmssConfig + """Configuration parameters for the PRF.""" + + @model_validator(mode="after") + def enforce_strict_types(self) -> "Prf": + """Validates that only exact approved types are used (rejects subclasses).""" + if type(self.config) is not XmssConfig: + raise TypeError(f"config must be exactly XmssConfig, got {type(self.config).__name__}") + return self def key_gen(self) -> PRFKey: """ @@ -228,8 +236,8 @@ def get_randomness( ] -PROD_PRF = Prf(PROD_CONFIG) +PROD_PRF = Prf(config=PROD_CONFIG) """An instance configured for production-level parameters.""" -TEST_PRF = Prf(TEST_CONFIG) +TEST_PRF = Prf(config=TEST_CONFIG) """A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/rand.py b/src/lean_spec/subspecs/xmss/rand.py new file mode 100644 index 00000000..175579a9 --- /dev/null +++ b/src/lean_spec/subspecs/xmss/rand.py @@ -0,0 +1,50 @@ +"""Random data generator for the XMSS signature scheme.""" + +import secrets +from typing import List + +from pydantic import model_validator + +from lean_spec.types import StrictBaseModel + +from ..koalabear import Fp, P +from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig +from .containers import HashDigest, Parameter, Randomness + + +class Rand(StrictBaseModel): + """An instance of the random data generator for a given config.""" + + config: XmssConfig + """Configuration parameters for the random generator.""" + + @model_validator(mode="after") + def enforce_strict_types(self) -> "Rand": + """Validates that only exact approved types are used (rejects subclasses).""" + if type(self.config) is not XmssConfig: + raise TypeError(f"config must be exactly XmssConfig, got {type(self.config).__name__}") + return self + + def field_elements(self, length: int) -> List[Fp]: + """Generates a random list of field elements.""" + # For each element, generate a secure random integer in the range [0, P-1]. + return [Fp(value=secrets.randbelow(P)) for _ in range(length)] + + def parameter(self) -> Parameter: + """Generates a random public parameter.""" + return self.field_elements(self.config.PARAMETER_LEN) + + def domain(self) -> HashDigest: + """Generates a random hash digest.""" + return self.field_elements(self.config.HASH_LEN_FE) + + def rho(self) -> Randomness: + """Generates randomness `rho` for message encoding.""" + return self.field_elements(self.config.RAND_LEN_FE) + + +PROD_RAND = Rand(config=PROD_CONFIG) +"""An instance configured for production-level parameters.""" + +TEST_RAND = Rand(config=TEST_CONFIG) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/subtree.py b/src/lean_spec/subspecs/xmss/subtree.py index 41e4bc31..65dc94f4 100644 --- a/src/lean_spec/subspecs/xmss/subtree.py +++ b/src/lean_spec/subspecs/xmss/subtree.py @@ -15,8 +15,8 @@ from .tweak_hash import TreeTweak if TYPE_CHECKING: + from .rand import Rand from .tweak_hash import TweakHasher - from .utils import Rand def _get_padded_layer(rand: Rand, nodes: List[HashDigest], start_index: int) -> HashTreeLayer: diff --git a/src/lean_spec/subspecs/xmss/target_sum.py b/src/lean_spec/subspecs/xmss/target_sum.py index 8239b71d..886b51b8 100644 --- a/src/lean_spec/subspecs/xmss/target_sum.py +++ b/src/lean_spec/subspecs/xmss/target_sum.py @@ -8,7 +8,9 @@ from typing import List, Optional -from lean_spec.types import Uint64 +from pydantic import model_validator + +from lean_spec.types import StrictBaseModel, Uint64 from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig from .containers import Parameter, Randomness @@ -19,7 +21,7 @@ ) -class TargetSumEncoder: +class TargetSumEncoder(StrictBaseModel): """ An instance of the Target Sum encoder for a given configuration. @@ -27,10 +29,23 @@ class TargetSumEncoder: scheme's target sum constraint. """ - def __init__(self, config: XmssConfig, message_hasher: MessageHasher): - """Initializes the encoder with a specific parameter set.""" - self.config = config - self.message_hasher = message_hasher + config: XmssConfig + """Configuration parameters for the encoder.""" + + message_hasher: MessageHasher + """Message hasher for encoding.""" + + @model_validator(mode="after") + def enforce_strict_types(self) -> "TargetSumEncoder": + """Validates that only exact approved types are used (rejects subclasses).""" + checks = {"config": XmssConfig, "message_hasher": MessageHasher} + for field, expected in checks.items(): + if type(getattr(self, field)) is not expected: + raise TypeError( + f"{field} must be exactly {expected.__name__}, " + f"got {type(getattr(self, field)).__name__}" + ) + return self def encode( self, parameter: Parameter, message: bytes, rho: Randomness, epoch: Uint64 @@ -80,8 +95,8 @@ def encode( return None -PROD_TARGET_SUM_ENCODER = TargetSumEncoder(PROD_CONFIG, PROD_MESSAGE_HASHER) +PROD_TARGET_SUM_ENCODER = TargetSumEncoder(config=PROD_CONFIG, message_hasher=PROD_MESSAGE_HASHER) """An instance configured for production-level parameters.""" -TEST_TARGET_SUM_ENCODER = TargetSumEncoder(TEST_CONFIG, TEST_MESSAGE_HASHER) +TEST_TARGET_SUM_ENCODER = TargetSumEncoder(config=TEST_CONFIG, message_hasher=TEST_MESSAGE_HASHER) """A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py index 751b95fa..0fd0c424 100644 --- a/src/lean_spec/subspecs/xmss/tweak_hash.py +++ b/src/lean_spec/subspecs/xmss/tweak_hash.py @@ -29,7 +29,7 @@ from itertools import chain from typing import List, Union -from pydantic import Field +from pydantic import Field, model_validator from lean_spec.types import StrictBaseModel, Uint64 @@ -79,18 +79,30 @@ class ChainTweak(StrictBaseModel): step: int = Field(ge=0, description="The step number within the chain (from 1 to BASE-1).") -class TweakHasher: +class TweakHasher(StrictBaseModel): """An instance of the Tweakable Hasher for a given config.""" - def __init__(self, config: XmssConfig, poseidon_hasher: PoseidonXmss): - """Initializes the hasher with a specific parameter set.""" - self.config = config - self.poseidon = poseidon_hasher + config: XmssConfig + """Configuration parameters for the hasher.""" - Tweak = Union[TreeTweak, ChainTweak] - """A type alias representing any valid tweak structure.""" + poseidon: PoseidonXmss + """Poseidon permutation instance for hashing.""" - def _encode_tweak(self, tweak: Tweak, length: int) -> List[Fp]: + @model_validator(mode="after") + def enforce_strict_types(self) -> "TweakHasher": + """Validates that only exact approved types are used (rejects subclasses).""" + from .poseidon import PoseidonXmss + + checks = {"config": XmssConfig, "poseidon": PoseidonXmss} + for field, expected in checks.items(): + if type(getattr(self, field)) is not expected: + raise TypeError( + f"{field} must be exactly {expected.__name__}, " + f"got {type(getattr(self, field)).__name__}" + ) + return self + + def _encode_tweak(self, tweak: Union[TreeTweak, ChainTweak], length: int) -> List[Fp]: """ Encodes a structured tweak object into a list of field elements. @@ -137,7 +149,7 @@ def _encode_tweak(self, tweak: Tweak, length: int) -> List[Fp]: def apply( self, parameter: Parameter, - tweak: Tweak, + tweak: Union[TreeTweak, ChainTweak], message_parts: List[HashDigest], ) -> HashDigest: """ @@ -246,8 +258,8 @@ def hash_chain( return current_digest -PROD_TWEAK_HASHER = TweakHasher(PROD_CONFIG, PROD_POSEIDON) +PROD_TWEAK_HASHER = TweakHasher(config=PROD_CONFIG, poseidon=PROD_POSEIDON) """An instance configured for production-level parameters.""" -TEST_TWEAK_HASHER = TweakHasher(TEST_CONFIG, TEST_POSEIDON) +TEST_TWEAK_HASHER = TweakHasher(config=TEST_CONFIG, poseidon=TEST_POSEIDON) """A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py index b9c3ea5f..ef1c7299 100644 --- a/src/lean_spec/subspecs/xmss/utils.py +++ b/src/lean_spec/subspecs/xmss/utils.py @@ -1,12 +1,11 @@ """Utility functions for the XMSS signature scheme.""" -import secrets from typing import TYPE_CHECKING, List from ...types.uint import Uint64 from ..koalabear import Fp, P -from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig -from .containers import HashDigest, Parameter, Randomness +from .constants import XmssConfig +from .containers import HashDigest if TYPE_CHECKING: from .merkle_tree import MerkleTree @@ -15,38 +14,6 @@ from .tweak_hash import TweakHasher -class Rand: - """An instance of the random data generator for a given config.""" - - def __init__(self, config: XmssConfig): - """Initializes the generator with a specific parameter set.""" - self.config = config - - def field_elements(self, length: int) -> List[Fp]: - """Generates a random list of field elements.""" - # For each element, generate a secure random integer in the range [0, P-1]. - return [Fp(value=secrets.randbelow(P)) for _ in range(length)] - - def parameter(self) -> Parameter: - """Generates a random public parameter.""" - return self.field_elements(self.config.PARAMETER_LEN) - - def domain(self) -> HashDigest: - """Generates a random hash digest.""" - return self.field_elements(self.config.HASH_LEN_FE) - - def rho(self) -> Randomness: - """Generates randomness `rho` for message encoding.""" - return self.field_elements(self.config.RAND_LEN_FE) - - -PROD_RAND = Rand(PROD_CONFIG) -"""An instance configured for production-level parameters.""" - -TEST_RAND = Rand(TEST_CONFIG) -"""A lightweight instance for test environments.""" - - def int_to_base_p(value: int, num_limbs: int) -> List[Fp]: """ Decomposes a large integer into a list of base-P field elements. diff --git a/tests/lean_spec/subspecs/xmss/test_message_hash.py b/tests/lean_spec/subspecs/xmss/test_message_hash.py index 1306829b..7fb451a2 100644 --- a/tests/lean_spec/subspecs/xmss/test_message_hash.py +++ b/tests/lean_spec/subspecs/xmss/test_message_hash.py @@ -10,7 +10,8 @@ from lean_spec.subspecs.xmss.message_hash import ( TEST_MESSAGE_HASHER, ) -from lean_spec.subspecs.xmss.utils import TEST_RAND, int_to_base_p +from lean_spec.subspecs.xmss.rand import TEST_RAND +from lean_spec.subspecs.xmss.utils import int_to_base_p from lean_spec.types import Uint64 diff --git a/tests/lean_spec/subspecs/xmss/test_strict_types.py b/tests/lean_spec/subspecs/xmss/test_strict_types.py new file mode 100644 index 00000000..d3013ad8 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_strict_types.py @@ -0,0 +1,535 @@ +""" +Tests for strict type checking in XMSS component classes. + +These tests verify that Pydantic-based classes properly reject subclasses, +ensuring only approved implementations are used. +""" + +import pytest +from pydantic import ValidationError + +from lean_spec.subspecs.xmss.constants import PROD_CONFIG, TEST_CONFIG, XmssConfig +from lean_spec.subspecs.xmss.interface import GeneralizedXmssScheme +from lean_spec.subspecs.xmss.merkle_tree import PROD_MERKLE_TREE, MerkleTree +from lean_spec.subspecs.xmss.message_hash import PROD_MESSAGE_HASHER, MessageHasher +from lean_spec.subspecs.xmss.poseidon import PROD_POSEIDON, PoseidonXmss +from lean_spec.subspecs.xmss.prf import PROD_PRF, Prf +from lean_spec.subspecs.xmss.rand import PROD_RAND, Rand +from lean_spec.subspecs.xmss.target_sum import PROD_TARGET_SUM_ENCODER, TargetSumEncoder +from lean_spec.subspecs.xmss.tweak_hash import PROD_TWEAK_HASHER, TweakHasher + + +class TestPrfStrictTypes: + """Tests for Prf strict type checking.""" + + def test_prf_accepts_exact_type(self) -> None: + """Prf initialization succeeds with exact type.""" + prf = Prf(config=PROD_CONFIG) + assert prf.config == PROD_CONFIG + + def test_prf_rejects_subclass_config(self) -> None: + """Prf rejects XmssConfig subclass.""" + + class CustomConfig(XmssConfig): + pass + + custom_config = XmssConfig.__new__(CustomConfig) + custom_config.__dict__.update(PROD_CONFIG.__dict__) + + with pytest.raises(TypeError, match="config must be exactly XmssConfig"): + Prf(config=custom_config) + + def test_prf_rejects_wrong_type_config(self) -> None: + """Prf rejects completely wrong type for config.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + Prf(config=RandomClass()) + + def test_prf_frozen(self) -> None: + """Prf is immutable (frozen).""" + with pytest.raises(ValidationError): + PROD_PRF.config = TEST_CONFIG + + +class TestRandStrictTypes: + """Tests for Rand strict type checking.""" + + def test_rand_accepts_exact_type(self) -> None: + """Rand initialization succeeds with exact type.""" + rand = Rand(config=PROD_CONFIG) + assert rand.config == PROD_CONFIG + + def test_rand_rejects_subclass_config(self) -> None: + """Rand rejects XmssConfig subclass.""" + + class CustomConfig(XmssConfig): + pass + + custom_config = XmssConfig.__new__(CustomConfig) + custom_config.__dict__.update(PROD_CONFIG.__dict__) + + with pytest.raises(TypeError, match="config must be exactly XmssConfig"): + Rand(config=custom_config) + + def test_rand_rejects_wrong_type_config(self) -> None: + """Rand rejects completely wrong type for config.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + Rand(config=RandomClass()) + + def test_rand_frozen(self) -> None: + """Rand is immutable (frozen).""" + with pytest.raises(ValidationError): + PROD_RAND.config = TEST_CONFIG + + +class TestTweakHasherStrictTypes: + """Tests for TweakHasher strict type checking.""" + + def test_tweak_hasher_accepts_exact_types(self) -> None: + """TweakHasher initialization succeeds with exact types.""" + hasher = TweakHasher(config=PROD_CONFIG, poseidon=PROD_POSEIDON) + assert hasher.config == PROD_CONFIG + + def test_tweak_hasher_rejects_subclass_config(self) -> None: + """TweakHasher rejects XmssConfig subclass.""" + + class CustomConfig(XmssConfig): + pass + + custom_config = XmssConfig.__new__(CustomConfig) + custom_config.__dict__.update(PROD_CONFIG.__dict__) + + with pytest.raises(TypeError, match="config must be exactly XmssConfig"): + TweakHasher(config=custom_config, poseidon=PROD_POSEIDON) + + def test_tweak_hasher_rejects_subclass_poseidon(self) -> None: + """TweakHasher rejects PoseidonXmss subclass.""" + + class CustomPoseidon(PoseidonXmss): + pass + + custom_poseidon = PoseidonXmss.__new__(CustomPoseidon) + custom_poseidon.__dict__.update(PROD_POSEIDON.__dict__) + + with pytest.raises(TypeError, match="poseidon must be exactly PoseidonXmss"): + TweakHasher(config=PROD_CONFIG, poseidon=custom_poseidon) + + def test_tweak_hasher_rejects_wrong_type_config(self) -> None: + """TweakHasher rejects completely wrong type for config.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + TweakHasher(config=RandomClass(), poseidon=PROD_POSEIDON) + + def test_tweak_hasher_rejects_wrong_type_poseidon(self) -> None: + """TweakHasher rejects completely wrong type for poseidon.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + TweakHasher(config=PROD_CONFIG, poseidon=RandomClass()) + + def test_tweak_hasher_frozen(self) -> None: + """TweakHasher is immutable (frozen).""" + with pytest.raises(ValidationError): + PROD_TWEAK_HASHER.config = TEST_CONFIG + + +class TestMerkleTreeStrictTypes: + """Tests for MerkleTree strict type checking.""" + + def test_merkle_tree_accepts_exact_types(self) -> None: + """MerkleTree initialization succeeds with exact types.""" + tree = MerkleTree(config=PROD_CONFIG, hasher=PROD_TWEAK_HASHER, rand=PROD_RAND) + assert tree.config == PROD_CONFIG + + def test_merkle_tree_rejects_subclass_config(self) -> None: + """MerkleTree rejects XmssConfig subclass.""" + + class CustomConfig(XmssConfig): + pass + + custom_config = XmssConfig.__new__(CustomConfig) + custom_config.__dict__.update(PROD_CONFIG.__dict__) + + with pytest.raises(TypeError, match="config must be exactly XmssConfig"): + MerkleTree(config=custom_config, hasher=PROD_TWEAK_HASHER, rand=PROD_RAND) + + def test_merkle_tree_rejects_subclass_hasher(self) -> None: + """MerkleTree rejects TweakHasher subclass.""" + + class CustomHasher(TweakHasher): + pass + + custom_hasher = TweakHasher.__new__(CustomHasher) + custom_hasher.__dict__.update(PROD_TWEAK_HASHER.__dict__) + + with pytest.raises(TypeError, match="hasher must be exactly TweakHasher"): + MerkleTree(config=PROD_CONFIG, hasher=custom_hasher, rand=PROD_RAND) + + def test_merkle_tree_rejects_subclass_rand(self) -> None: + """MerkleTree rejects Rand subclass.""" + + class CustomRand(Rand): + pass + + custom_rand = Rand.__new__(CustomRand) + custom_rand.__dict__.update(PROD_RAND.__dict__) + + with pytest.raises(TypeError, match="rand must be exactly Rand"): + MerkleTree(config=PROD_CONFIG, hasher=PROD_TWEAK_HASHER, rand=custom_rand) + + def test_merkle_tree_rejects_wrong_type_config(self) -> None: + """MerkleTree rejects completely wrong type for config.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + MerkleTree(config=RandomClass(), hasher=PROD_TWEAK_HASHER, rand=PROD_RAND) + + def test_merkle_tree_rejects_wrong_type_hasher(self) -> None: + """MerkleTree rejects completely wrong type for hasher.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + MerkleTree(config=PROD_CONFIG, hasher=RandomClass(), rand=PROD_RAND) + + def test_merkle_tree_rejects_wrong_type_rand(self) -> None: + """MerkleTree rejects completely wrong type for rand.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + MerkleTree(config=PROD_CONFIG, hasher=PROD_TWEAK_HASHER, rand=RandomClass()) + + def test_merkle_tree_frozen(self) -> None: + """MerkleTree is immutable (frozen).""" + with pytest.raises(ValidationError): + PROD_MERKLE_TREE.config = TEST_CONFIG + + +class TestTargetSumEncoderStrictTypes: + """Tests for TargetSumEncoder strict type checking.""" + + def test_encoder_accepts_exact_types(self) -> None: + """TargetSumEncoder initialization succeeds with exact types.""" + encoder = TargetSumEncoder(config=PROD_CONFIG, message_hasher=PROD_MESSAGE_HASHER) + assert encoder.config == PROD_CONFIG + + def test_encoder_rejects_subclass_config(self) -> None: + """TargetSumEncoder rejects XmssConfig subclass.""" + + class CustomConfig(XmssConfig): + pass + + custom_config = XmssConfig.__new__(CustomConfig) + custom_config.__dict__.update(PROD_CONFIG.__dict__) + + with pytest.raises(TypeError, match="config must be exactly XmssConfig"): + TargetSumEncoder(config=custom_config, message_hasher=PROD_MESSAGE_HASHER) + + def test_encoder_rejects_subclass_message_hasher(self) -> None: + """TargetSumEncoder rejects MessageHasher subclass.""" + + class CustomMessageHasher(MessageHasher): + pass + + custom_hasher = MessageHasher.__new__(CustomMessageHasher) + custom_hasher.__dict__.update(PROD_MESSAGE_HASHER.__dict__) + + with pytest.raises(TypeError, match="message_hasher must be exactly MessageHasher"): + TargetSumEncoder(config=PROD_CONFIG, message_hasher=custom_hasher) + + def test_encoder_rejects_wrong_type_config(self) -> None: + """TargetSumEncoder rejects completely wrong type for config.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + TargetSumEncoder(config=RandomClass(), message_hasher=PROD_MESSAGE_HASHER) + + def test_encoder_rejects_wrong_type_message_hasher(self) -> None: + """TargetSumEncoder rejects completely wrong type for message_hasher.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + TargetSumEncoder(config=PROD_CONFIG, message_hasher=RandomClass()) + + def test_encoder_frozen(self) -> None: + """TargetSumEncoder is immutable (frozen).""" + with pytest.raises(ValidationError): + PROD_TARGET_SUM_ENCODER.config = TEST_CONFIG + + +class TestGeneralizedXmssSchemeStrictTypes: + """Tests for GeneralizedXmssScheme strict type checking (integration).""" + + def test_scheme_accepts_exact_types(self) -> None: + """GeneralizedXmssScheme initialization succeeds with exact types.""" + scheme = GeneralizedXmssScheme( + config=PROD_CONFIG, + prf=PROD_PRF, + hasher=PROD_TWEAK_HASHER, + merkle_tree=PROD_MERKLE_TREE, + encoder=PROD_TARGET_SUM_ENCODER, + rand=PROD_RAND, + ) + assert scheme.config == PROD_CONFIG + + def test_scheme_rejects_subclass_config(self) -> None: + """GeneralizedXmssScheme rejects XmssConfig subclass.""" + + class CustomConfig(XmssConfig): + pass + + custom_config = XmssConfig.__new__(CustomConfig) + custom_config.__dict__.update(PROD_CONFIG.__dict__) + + with pytest.raises(TypeError, match="config must be exactly XmssConfig"): + GeneralizedXmssScheme( + config=custom_config, + prf=PROD_PRF, + hasher=PROD_TWEAK_HASHER, + merkle_tree=PROD_MERKLE_TREE, + encoder=PROD_TARGET_SUM_ENCODER, + rand=PROD_RAND, + ) + + def test_scheme_rejects_subclass_prf(self) -> None: + """GeneralizedXmssScheme rejects Prf subclass.""" + + class CustomPrf(Prf): + pass + + custom_prf = Prf.__new__(CustomPrf) + custom_prf.__dict__.update(PROD_PRF.__dict__) + + with pytest.raises(TypeError, match="prf must be exactly Prf"): + GeneralizedXmssScheme( + config=PROD_CONFIG, + prf=custom_prf, + hasher=PROD_TWEAK_HASHER, + merkle_tree=PROD_MERKLE_TREE, + encoder=PROD_TARGET_SUM_ENCODER, + rand=PROD_RAND, + ) + + def test_scheme_rejects_subclass_hasher(self) -> None: + """GeneralizedXmssScheme rejects TweakHasher subclass.""" + + class CustomHasher(TweakHasher): + pass + + custom_hasher = TweakHasher.__new__(CustomHasher) + custom_hasher.__dict__.update(PROD_TWEAK_HASHER.__dict__) + + with pytest.raises(TypeError, match="hasher must be exactly TweakHasher"): + GeneralizedXmssScheme( + config=PROD_CONFIG, + prf=PROD_PRF, + hasher=custom_hasher, + merkle_tree=PROD_MERKLE_TREE, + encoder=PROD_TARGET_SUM_ENCODER, + rand=PROD_RAND, + ) + + def test_scheme_rejects_subclass_merkle_tree(self) -> None: + """GeneralizedXmssScheme rejects MerkleTree subclass.""" + + class CustomMerkleTree(MerkleTree): + pass + + custom_tree = MerkleTree.__new__(CustomMerkleTree) + custom_tree.__dict__.update(PROD_MERKLE_TREE.__dict__) + + with pytest.raises(TypeError, match="merkle_tree must be exactly MerkleTree"): + GeneralizedXmssScheme( + config=PROD_CONFIG, + prf=PROD_PRF, + hasher=PROD_TWEAK_HASHER, + merkle_tree=custom_tree, + encoder=PROD_TARGET_SUM_ENCODER, + rand=PROD_RAND, + ) + + def test_scheme_rejects_subclass_encoder(self) -> None: + """GeneralizedXmssScheme rejects TargetSumEncoder subclass.""" + + class CustomEncoder(TargetSumEncoder): + pass + + custom_encoder = TargetSumEncoder.__new__(CustomEncoder) + custom_encoder.__dict__.update(PROD_TARGET_SUM_ENCODER.__dict__) + + with pytest.raises(TypeError, match="encoder must be exactly TargetSumEncoder"): + GeneralizedXmssScheme( + config=PROD_CONFIG, + prf=PROD_PRF, + hasher=PROD_TWEAK_HASHER, + merkle_tree=PROD_MERKLE_TREE, + encoder=custom_encoder, + rand=PROD_RAND, + ) + + def test_scheme_rejects_subclass_rand(self) -> None: + """GeneralizedXmssScheme rejects Rand subclass.""" + + class CustomRand(Rand): + pass + + custom_rand = Rand.__new__(CustomRand) + custom_rand.__dict__.update(PROD_RAND.__dict__) + + with pytest.raises(TypeError, match="rand must be exactly Rand"): + GeneralizedXmssScheme( + config=PROD_CONFIG, + prf=PROD_PRF, + hasher=PROD_TWEAK_HASHER, + merkle_tree=PROD_MERKLE_TREE, + encoder=PROD_TARGET_SUM_ENCODER, + rand=custom_rand, + ) + + def test_scheme_rejects_extra_fields(self) -> None: + """GeneralizedXmssScheme rejects extra fields.""" + with pytest.raises(ValidationError): + GeneralizedXmssScheme( + config=PROD_CONFIG, + prf=PROD_PRF, + hasher=PROD_TWEAK_HASHER, + merkle_tree=PROD_MERKLE_TREE, + encoder=PROD_TARGET_SUM_ENCODER, + rand=PROD_RAND, + extra_field="should_fail", + ) + + +class TestPoseidonXmssStrictTypes: + """Tests for PoseidonXmss strict type checking.""" + + def test_poseidon_accepts_exact_types(self) -> None: + """PoseidonXmss initialization succeeds with exact types.""" + poseidon = PoseidonXmss(params16=PROD_POSEIDON.params16, params24=PROD_POSEIDON.params24) + assert poseidon.params16 == PROD_POSEIDON.params16 + + def test_poseidon_rejects_subclass_params16(self) -> None: + """PoseidonXmss rejects Poseidon2Params subclass for params16.""" + from lean_spec.subspecs.poseidon2.permutation import Poseidon2Params + + class CustomParams(Poseidon2Params): + pass + + custom_params = Poseidon2Params.__new__(CustomParams) + custom_params.__dict__.update(PROD_POSEIDON.params16.__dict__) + + with pytest.raises(TypeError, match="params16 must be exactly Poseidon2Params"): + PoseidonXmss(params16=custom_params, params24=PROD_POSEIDON.params24) + + def test_poseidon_rejects_subclass_params24(self) -> None: + """PoseidonXmss rejects Poseidon2Params subclass for params24.""" + from lean_spec.subspecs.poseidon2.permutation import Poseidon2Params + + class CustomParams(Poseidon2Params): + pass + + custom_params = Poseidon2Params.__new__(CustomParams) + custom_params.__dict__.update(PROD_POSEIDON.params24.__dict__) + + with pytest.raises(TypeError, match="params24 must be exactly Poseidon2Params"): + PoseidonXmss(params16=PROD_POSEIDON.params16, params24=custom_params) + + def test_poseidon_rejects_wrong_type_params16(self) -> None: + """PoseidonXmss rejects completely wrong type for params16.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + PoseidonXmss(params16=RandomClass(), params24=PROD_POSEIDON.params24) + + def test_poseidon_rejects_wrong_type_params24(self) -> None: + """PoseidonXmss rejects completely wrong type for params24.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + PoseidonXmss(params16=PROD_POSEIDON.params16, params24=RandomClass()) + + def test_poseidon_frozen(self) -> None: + """PoseidonXmss is immutable (frozen).""" + with pytest.raises(ValidationError): + PROD_POSEIDON.params16 = PROD_POSEIDON.params24 + + +class TestMessageHasherStrictTypes: + """Tests for MessageHasher strict type checking.""" + + def test_message_hasher_accepts_exact_types(self) -> None: + """MessageHasher initialization succeeds with exact types.""" + hasher = MessageHasher(config=PROD_CONFIG, poseidon=PROD_POSEIDON) + assert hasher.config == PROD_CONFIG + + def test_message_hasher_rejects_subclass_config(self) -> None: + """MessageHasher rejects XmssConfig subclass.""" + + class CustomConfig(XmssConfig): + pass + + custom_config = XmssConfig.__new__(CustomConfig) + custom_config.__dict__.update(PROD_CONFIG.__dict__) + + with pytest.raises(TypeError, match="config must be exactly XmssConfig"): + MessageHasher(config=custom_config, poseidon=PROD_POSEIDON) + + def test_message_hasher_rejects_subclass_poseidon(self) -> None: + """MessageHasher rejects PoseidonXmss subclass.""" + + class CustomPoseidon(PoseidonXmss): + pass + + custom_poseidon = PoseidonXmss.__new__(CustomPoseidon) + custom_poseidon.__dict__.update(PROD_POSEIDON.__dict__) + + with pytest.raises(TypeError, match="poseidon must be exactly PoseidonXmss"): + MessageHasher(config=PROD_CONFIG, poseidon=custom_poseidon) + + def test_message_hasher_rejects_wrong_type_config(self) -> None: + """MessageHasher rejects completely wrong type for config.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + MessageHasher(config=RandomClass(), poseidon=PROD_POSEIDON) + + def test_message_hasher_rejects_wrong_type_poseidon(self) -> None: + """MessageHasher rejects completely wrong type for poseidon.""" + + class RandomClass: + pass + + with pytest.raises((TypeError, ValidationError)): + MessageHasher(config=PROD_CONFIG, poseidon=RandomClass()) + + def test_message_hasher_frozen(self) -> None: + """MessageHasher is immutable (frozen).""" + with pytest.raises(ValidationError): + PROD_MESSAGE_HASHER.config = TEST_CONFIG