diff --git a/src/iter/mod.rs b/src/iter/mod.rs
index 02e59b6e2..f8dbed939 100644
--- a/src/iter/mod.rs
+++ b/src/iter/mod.rs
@@ -15,7 +15,7 @@ pub use tree::{
 };
 
 use crate::sync::Arc;
-use crate::{Miniscript, MiniscriptKey, ScriptContext, Terminal};
+use crate::{policy, Miniscript, MiniscriptKey, ScriptContext, Terminal};
 
 impl<'a, Pk: MiniscriptKey, Ctx: ScriptContext> TreeLike for &'a Miniscript<Pk, Ctx> {
     fn as_node(&self) -> Tree<Self> {
@@ -68,3 +68,29 @@ impl<Pk: MiniscriptKey, Ctx: ScriptContext> TreeLike for Arc<Miniscript<Pk, Ctx>> {
         }
     }
 }
+
+impl<'a, Pk: MiniscriptKey> TreeLike for &'a policy::Concrete<Pk> {
+    fn as_node(&self) -> Tree<Self> {
+        use policy::Concrete::*;
+        match *self {
+            Unsatisfiable | Trivial | Key(_) | After(_) | Older(_) | Sha256(_) | Hash256(_)
+            | Ripemd160(_) | Hash160(_) => Tree::Nullary,
+            And(ref subs) => Tree::Nary(subs.iter().map(Arc::as_ref).collect()),
+            Or(ref v) => Tree::Nary(v.iter().map(|(_, p)| p.as_ref()).collect()),
+            Threshold(_, ref subs) => Tree::Nary(subs.iter().map(Arc::as_ref).collect()),
+        }
+    }
+}
+
+impl<Pk: MiniscriptKey> TreeLike for Arc<policy::Concrete<Pk>> {
+    fn as_node(&self) -> Tree<Self> {
+        use policy::Concrete::*;
+        match self.as_ref() {
+            Unsatisfiable | Trivial | Key(_) | After(_) | Older(_) | Sha256(_) | Hash256(_)
+            | Ripemd160(_) | Hash160(_) => Tree::Nullary,
+            And(ref subs) => Tree::Nary(subs.iter().map(Arc::clone).collect()),
+            Or(ref v) => Tree::Nary(v.iter().map(|(_, p)| Arc::clone(p)).collect()),
+            Threshold(_, ref subs) => Tree::Nary(subs.iter().map(Arc::clone).collect()),
+        }
+    }
+}
diff --git a/src/miniscript/mod.rs b/src/miniscript/mod.rs
index fa75a712c..28349e427 100644
--- a/src/miniscript/mod.rs
+++ b/src/miniscript/mod.rs
@@ -441,7 +441,6 @@ impl<Pk: MiniscriptKey, Ctx: ScriptContext> Miniscript<Pk, Ctx> {
     {
         let mut translated = vec![];
         for data in Arc::new(self.clone()).post_order_iter() {
-            // convenience method to reduce typing
             let child_n = |n| Arc::clone(&translated[data.child_indices[n]]);
 
             let new_term = match data.node.node {
diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs
index 6880af9b0..66284f046 100644
--- a/src/policy/compiler.rs
+++ b/src/policy/compiler.rs
@@ -786,10 +786,14 @@ where
         }
         Concrete::And(ref subs) => {
             assert_eq!(subs.len(), 2, "and takes 2 args");
-            let mut left = best_compilations(policy_cache, &subs[0], sat_prob, dissat_prob)?;
-            let mut right = best_compilations(policy_cache, &subs[1], sat_prob, dissat_prob)?;
-            let mut q_zero_right = best_compilations(policy_cache, &subs[1], sat_prob, None)?;
-            let mut q_zero_left = best_compilations(policy_cache, &subs[0], sat_prob, None)?;
+            let mut left =
+                best_compilations(policy_cache, subs[0].as_ref(), sat_prob, dissat_prob)?;
+            let mut right =
+                best_compilations(policy_cache, subs[1].as_ref(), sat_prob, dissat_prob)?;
+            let mut q_zero_right =
+                best_compilations(policy_cache, subs[1].as_ref(), sat_prob, None)?;
+            let mut q_zero_left =
+                best_compilations(policy_cache, subs[0].as_ref(), sat_prob, None)?;
 
             compile_binary!(&mut left, &mut right, [1.0, 1.0], Terminal::AndB);
             compile_binary!(&mut right, &mut left, [1.0, 1.0], Terminal::AndB);
@@ -813,48 +817,56 @@ where
             let rw = subs[1].0 as f64 / total;
 
             //and-or
-            if let (Concrete::And(x), _) = (&subs[0].1, &subs[1].1) {
+            if let (Concrete::And(x), _) = (subs[0].1.as_ref(), subs[1].1.as_ref()) {
                 let mut a1 = best_compilations(
                     policy_cache,
-                    &x[0],
+                    x[0].as_ref(),
                     lw * sat_prob,
                     Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob),
                 )?;
-                let mut a2 = best_compilations(policy_cache, &x[0], lw * sat_prob, None)?;
+                let mut a2 = best_compilations(policy_cache, x[0].as_ref(), lw * sat_prob, None)?;
 
                 let mut b1 = best_compilations(
                     policy_cache,
-                    &x[1],
+                    x[1].as_ref(),
                     lw * sat_prob,
                     Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob),
                 )?;
-                let mut b2 = best_compilations(policy_cache, &x[1], lw * sat_prob, None)?;
+                let mut b2 = best_compilations(policy_cache, x[1].as_ref(), lw * sat_prob, None)?;
 
-                let mut c =
-                    best_compilations(policy_cache, &subs[1].1, rw * sat_prob, dissat_prob)?;
+                let mut c = best_compilations(
+                    policy_cache,
+                    subs[1].1.as_ref(),
+                    rw * sat_prob,
+                    dissat_prob,
+                )?;
 
                 compile_tern!(&mut a1, &mut b2, &mut c, [lw, rw]);
                 compile_tern!(&mut b1, &mut a2, &mut c, [lw, rw]);
             };
-            if let (_, Concrete::And(x)) = (&subs[0].1, &subs[1].1) {
+            if let (_, Concrete::And(x)) = (&subs[0].1.as_ref(), subs[1].1.as_ref()) {
                 let mut a1 = best_compilations(
                     policy_cache,
-                    &x[0],
+                    x[0].as_ref(),
                     rw * sat_prob,
                     Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob),
                 )?;
-                let mut a2 = best_compilations(policy_cache, &x[0], rw * sat_prob, None)?;
+                let mut a2 = best_compilations(policy_cache, x[0].as_ref(), rw * sat_prob, None)?;
 
                 let mut b1 = best_compilations(
                     policy_cache,
-                    &x[1],
+                    x[1].as_ref(),
                     rw * sat_prob,
                     Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob),
                 )?;
-                let mut b2 = best_compilations(policy_cache, &x[1], rw * sat_prob, None)?;
+                let mut b2 = best_compilations(policy_cache, x[1].as_ref(), rw * sat_prob, None)?;
 
-                let mut c =
-                    best_compilations(policy_cache, &subs[0].1, lw * sat_prob, dissat_prob)?;
+                let mut c = best_compilations(
+                    policy_cache,
+                    subs[0].1.as_ref(),
+                    lw * sat_prob,
+                    dissat_prob,
+                )?;
 
                 compile_tern!(&mut a1, &mut b2, &mut c, [rw, lw]);
                 compile_tern!(&mut b1, &mut a2, &mut c, [rw, lw]);
@@ -873,12 +885,22 @@ where
             let mut r_comp = vec![];
 
             for dissat_prob in dissat_probs(rw).iter() {
-                let l = best_compilations(policy_cache, &subs[0].1, lw * sat_prob, *dissat_prob)?;
+                let l = best_compilations(
+                    policy_cache,
+                    subs[0].1.as_ref(),
+                    lw * sat_prob,
+                    *dissat_prob,
+                )?;
                 l_comp.push(l);
             }
 
             for dissat_prob in dissat_probs(lw).iter() {
-                let r = best_compilations(policy_cache, &subs[1].1, rw * sat_prob, *dissat_prob)?;
+                let r = best_compilations(
+                    policy_cache,
+                    subs[1].1.as_ref(),
+                    rw * sat_prob,
+                    *dissat_prob,
+                )?;
                 r_comp.push(r);
             }
@@ -913,8 +935,8 @@ where
                 let sp = sat_prob * k_over_n;
                 //Expressions must be dissatisfiable
                 let dp = Some(dissat_prob.unwrap_or(0 as f64) + (1.0 - k_over_n) * sat_prob);
-                let be = best(types::Base::B, policy_cache, ast, sp, dp)?;
-                let bw = best(types::Base::W, policy_cache, ast, sp, dp)?;
+                let be = best(types::Base::B, policy_cache, ast.as_ref(), sp, dp)?;
+                let bw = best(types::Base::W, policy_cache, ast.as_ref(), sp, dp)?;
 
                 let diff = be.cost_1d(sp, dp) - bw.cost_1d(sp, dp);
                 best_es.push((be.comp_ext_data, be));
@@ -947,7 +969,7 @@ where
             let key_vec: Vec<Pk> = subs
                 .iter()
                 .filter_map(|s| {
-                    if let Concrete::Key(ref pk) = *s {
+                    if let Concrete::Key(ref pk) = s.as_ref() {
                         Some(pk.clone())
                     } else {
                         None
@@ -967,9 +989,10 @@ where
                 _ if k == subs.len() => {
                     let mut it = subs.iter();
                     let mut policy = it.next().expect("No sub policy in thresh() ?").clone();
-                    policy = it.fold(policy, |acc, pol| Concrete::And(vec![acc, pol.clone()]));
+                    policy =
+                        it.fold(policy, |acc, pol| Concrete::And(vec![acc, pol.clone()]).into());
 
-                    ret = best_compilations(policy_cache, &policy, sat_prob, dissat_prob)?;
+                    ret = best_compilations(policy_cache, policy.as_ref(), sat_prob, dissat_prob)?;
                 }
                 _ => {}
             }
@@ -1178,8 +1201,11 @@ mod tests {
     fn compile_timelocks() {
         // artificially create a policy that is problematic and try to compile
         let pol: SPolicy = Concrete::And(vec![
-            Concrete::Key("A".to_string()),
-            Concrete::And(vec![Concrete::after(9), Concrete::after(1000_000_000)]),
+            Arc::new(Concrete::Key("A".to_string())),
+            Arc::new(Concrete::And(vec![
+                Arc::new(Concrete::after(9)),
+                Arc::new(Concrete::after(1000_000_000)),
+            ])),
         ]);
 
         assert!(pol.compile::<Segwitv0>().is_err());
@@ -1273,13 +1299,22 @@ mod tests {
 
         // Liquid policy
         let policy: BPolicy = Concrete::Or(vec![
-            (127, Concrete::Threshold(3, key_pol[0..5].to_owned())),
+            (
+                127,
+                Arc::new(Concrete::Threshold(
+                    3,
+                    key_pol[0..5].iter().map(|p| (p.clone()).into()).collect(),
+                )),
+            ),
             (
                 1,
-                Concrete::And(vec![
-                    Concrete::Older(Sequence::from_height(10000)),
-                    Concrete::Threshold(2, key_pol[5..8].to_owned()),
-                ]),
+                Arc::new(Concrete::And(vec![
+                    Arc::new(Concrete::Older(Sequence::from_height(10000))),
+                    Arc::new(Concrete::Threshold(
+                        2,
+                        key_pol[5..8].iter().map(|p| (p.clone()).into()).collect(),
+                    )),
+                ])),
             ),
         ]);
@@ -1391,8 +1426,10 @@ mod tests {
         // and to a ms thresh otherwise.
         // k = 1 (or 2) does not compile, see https://github.com/rust-bitcoin/rust-miniscript/issues/114
         for k in &[10, 15, 21] {
-            let pubkeys: Vec<Concrete<bitcoin::PublicKey>> =
-                keys.iter().map(|pubkey| Concrete::Key(*pubkey)).collect();
+            let pubkeys: Vec<Arc<Concrete<bitcoin::PublicKey>>> = keys
+                .iter()
+                .map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
+                .collect();
             let big_thresh = Concrete::Threshold(*k, pubkeys);
             let big_thresh_ms: SegwitMiniScript = big_thresh.compile().unwrap();
             if *k == 21 {
@@ -1419,18 +1456,18 @@ mod tests {
         // or(thresh(52, [pubkey; 52]), thresh(52, [pubkey; 52])) results in a 3642-bytes long
         // witness script with only 54 stack elements
         let (keys, _) = pubkeys_and_a_sig(104);
-        let keys_a: Vec<Concrete<bitcoin::PublicKey>> = keys[..keys.len() / 2]
+        let keys_a: Vec<Arc<Concrete<bitcoin::PublicKey>>> = keys[..keys.len() / 2]
             .iter()
-            .map(|pubkey| Concrete::Key(*pubkey))
+            .map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
             .collect();
-        let keys_b: Vec<Concrete<bitcoin::PublicKey>> = keys[keys.len() / 2..]
+        let keys_b: Vec<Arc<Concrete<bitcoin::PublicKey>>> = keys[keys.len() / 2..]
             .iter()
-            .map(|pubkey| Concrete::Key(*pubkey))
+            .map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
             .collect();
 
         let thresh_res: Result<SegwitMiniScript, _> = Concrete::Or(vec![
-            (1, Concrete::Threshold(keys_a.len(), keys_a)),
-            (1, Concrete::Threshold(keys_b.len(), keys_b)),
+            (1, Arc::new(Concrete::Threshold(keys_a.len(), keys_a))),
+            (1, Arc::new(Concrete::Threshold(keys_b.len(), keys_b))),
         ])
         .compile();
         let script_size = thresh_res.clone().and_then(|m| Ok(m.script_size()));
@@ -1443,8 +1480,10 @@ mod tests {
 
         // Hit the maximum witness stack elements limit
         let (keys, _) = pubkeys_and_a_sig(100);
-        let keys: Vec<Concrete<bitcoin::PublicKey>> =
-            keys.iter().map(|pubkey| Concrete::Key(*pubkey)).collect();
+        let keys: Vec<Arc<Concrete<bitcoin::PublicKey>>> = keys
+            .iter()
+            .map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
+            .collect();
         let thresh_res: Result<SegwitMiniScript, _> =
             Concrete::Threshold(keys.len(), keys).compile();
         let n_elements = thresh_res
@@ -1462,8 +1501,10 @@ mod tests {
     fn shared_limits() {
         // Test the maximum number of OPs with a 67-of-68 multisig
        let (keys, _) = pubkeys_and_a_sig(68);
-        let keys: Vec<Concrete<bitcoin::PublicKey>> =
-            keys.iter().map(|pubkey| Concrete::Key(*pubkey)).collect();
+        let keys: Vec<Arc<Concrete<bitcoin::PublicKey>>> = keys
+            .iter()
+            .map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
+            .collect();
         let thresh_res: Result<SegwitMiniScript, _> =
             Concrete::Threshold(keys.len() - 1, keys).compile();
         let ops_count = thresh_res.clone().and_then(|m| Ok(m.ext.ops.op_count()));
@@ -1475,8 +1516,10 @@ mod tests {
         );
         // For legacy too..
         let (keys, _) = pubkeys_and_a_sig(68);
-        let keys: Vec<Concrete<bitcoin::PublicKey>> =
-            keys.iter().map(|pubkey| Concrete::Key(*pubkey)).collect();
+        let keys: Vec<Arc<Concrete<bitcoin::PublicKey>>> = keys
+            .iter()
+            .map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
+            .collect();
         let thresh_res = Concrete::Threshold(keys.len() - 1, keys).compile::<Legacy>();
         let ops_count = thresh_res.clone().and_then(|m| Ok(m.ext.ops.op_count()));
         assert_eq!(
@@ -1488,8 +1531,9 @@ mod tests {
 
         // Test that we refuse to compile policies with duplicated keys
         let (keys, _) = pubkeys_and_a_sig(1);
-        let key = Concrete::Key(keys[0]);
-        let res = Concrete::Or(vec![(1, key.clone()), (1, key.clone())]).compile::<Segwitv0>();
+        let key = Arc::new(Concrete::Key(keys[0]));
+        let res =
+            Concrete::Or(vec![(1, Arc::clone(&key)), (1, Arc::clone(&key))]).compile::<Segwitv0>();
         assert_eq!(
             res,
             Err(CompilerError::PolicyError(policy::concrete::PolicyError::DuplicatePubKeys))
diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs
index bdc778d0c..ba88d7c02 100644
--- a/src/policy/concrete.rs
+++ b/src/policy/concrete.rs
@@ -19,13 +19,14 @@ use {
     crate::Miniscript,
     crate::Tap,
     core::cmp::Reverse,
-    sync::Arc,
 };
 
 use super::ENTAILMENT_MAX_TERMINALS;
 use crate::expression::{self, FromTree};
+use crate::iter::TreeLike;
 use crate::miniscript::types::extra_props::TimelockInfo;
 use crate::prelude::*;
+use crate::sync::Arc;
 #[cfg(all(doc, not(feature = "compiler")))]
 use crate::Descriptor;
 use crate::{errstr, AbsLockTime, Error, ForEachKey, MiniscriptKey, Translator};
@@ -58,12 +59,12 @@ pub enum Policy<Pk: MiniscriptKey> {
     /// A HASH160 whose preimage must be provided to satisfy the descriptor.
     Hash160(Pk::Hash160),
     /// A list of sub-policies, all of which must be satisfied.
-    And(Vec<Policy<Pk>>),
+    And(Vec<Arc<Policy<Pk>>>),
     /// A list of sub-policies, one of which must be satisfied, along with
     /// relative probabilities for each one.
-    Or(Vec<(usize, Policy<Pk>)>),
+    Or(Vec<(usize, Arc<Policy<Pk>>)>),
     /// A set of descriptors, satisfactions must be provided for `k` of them.
-    Threshold(usize, Vec<Policy<Pk>>),
+    Threshold(usize, Vec<Arc<Policy<Pk>>>),
 }
 
 impl<Pk> Policy<Pk>
 where
@@ -81,105 +82,6 @@ where
     pub fn older(n: u32) -> Policy<Pk> { Policy::Older(Sequence::from_consensus(n)) }
 }
 
-/// Lightweight repr of Concrete policy which corresponds directly to a
-/// Miniscript structure, and whose disjunctions are annotated with satisfaction
-/// probabilities to assist the compiler
-#[cfg(feature = "compiler")]
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
-enum PolicyArc<Pk: MiniscriptKey> {
-    /// Unsatisfiable
-    Unsatisfiable,
-    /// Trivially satisfiable
-    Trivial,
-    /// A public key which must sign to satisfy the descriptor
-    Key(Pk),
-    /// An absolute locktime restriction
-    After(AbsLockTime),
-    /// A relative locktime restriction
-    Older(u32),
-    /// A SHA256 whose preimage must be provided to satisfy the descriptor
-    Sha256(Pk::Sha256),
-    /// A SHA256d whose preimage must be provided to satisfy the descriptor
-    Hash256(Pk::Hash256),
-    /// A RIPEMD160 whose preimage must be provided to satisfy the descriptor
-    Ripemd160(Pk::Ripemd160),
-    /// A HASH160 whose preimage must be provided to satisfy the descriptor
-    Hash160(Pk::Hash160),
-    /// A list of sub-policies' references, all of which must be satisfied
-    And(Vec<Arc<PolicyArc<Pk>>>),
-    /// A list of sub-policies's references, one of which must be satisfied,
-    /// along with relative probabilities for each one
-    Or(Vec<(usize, Arc<PolicyArc<Pk>>)>),
-    /// A set of descriptors' references, satisfactions must be provided for `k` of them
-    Threshold(usize, Vec<Arc<PolicyArc<Pk>>>),
-}
-
-#[cfg(feature = "compiler")]
-impl<Pk: MiniscriptKey> From<PolicyArc<Pk>> for Policy<Pk> {
-    fn from(p: PolicyArc<Pk>) -> Self {
-        match p {
-            PolicyArc::Unsatisfiable => Policy::Unsatisfiable,
-            PolicyArc::Trivial => Policy::Trivial,
-            PolicyArc::Key(pk) => Policy::Key(pk),
-            PolicyArc::After(t) => Policy::After(t),
-            PolicyArc::Older(t) => Policy::Older(Sequence::from_consensus(t)),
-            PolicyArc::Sha256(hash) => Policy::Sha256(hash),
-            PolicyArc::Hash256(hash) => Policy::Hash256(hash),
-            PolicyArc::Ripemd160(hash) => Policy::Ripemd160(hash),
-            PolicyArc::Hash160(hash) => Policy::Hash160(hash),
-            PolicyArc::And(subs) => Policy::And(
-                subs.into_iter()
-                    .map(|pol| Self::from((*pol).clone()))
-                    .collect(),
-            ),
-            PolicyArc::Or(subs) => Policy::Or(
-                subs.into_iter()
-                    .map(|(odds, sub)| (odds, Self::from((*sub).clone())))
-                    .collect(),
-            ),
-            PolicyArc::Threshold(k, subs) => Policy::Threshold(
-                k,
-                subs.into_iter()
-                    .map(|pol| Self::from((*pol).clone()))
-                    .collect(),
-            ),
-        }
-    }
-}
-
-#[cfg(feature = "compiler")]
-impl<Pk: MiniscriptKey> From<Policy<Pk>> for PolicyArc<Pk> {
-    fn from(p: Policy<Pk>) -> Self {
-        match p {
-            Policy::Unsatisfiable => PolicyArc::Unsatisfiable,
-            Policy::Trivial => PolicyArc::Trivial,
-            Policy::Key(pk) => PolicyArc::Key(pk),
-            Policy::After(lock_time) => PolicyArc::After(lock_time),
-            Policy::Older(Sequence(t)) => PolicyArc::Older(t),
-            Policy::Sha256(hash) => PolicyArc::Sha256(hash),
-            Policy::Hash256(hash) => PolicyArc::Hash256(hash),
-            Policy::Ripemd160(hash) => PolicyArc::Ripemd160(hash),
-            Policy::Hash160(hash) => PolicyArc::Hash160(hash),
-            Policy::And(subs) => PolicyArc::And(
-                subs.iter()
-                    .map(|sub| Arc::new(Self::from(sub.clone())))
-                    .collect(),
-            ),
-            Policy::Or(subs) => PolicyArc::Or(
-                subs.iter()
-                    .map(|(odds, sub)| (*odds, Arc::new(Self::from(sub.clone()))))
-                    .collect(),
-            ),
-            Policy::Threshold(k, subs) => PolicyArc::Threshold(
-                k,
-                subs.iter()
-                    .map(|sub| Arc::new(Self::from(sub.clone())))
-                    .collect(),
-            ),
-        }
-    }
-}
-
 /// Detailed error type for concrete policies.
 #[derive(Copy, Clone, PartialEq, Eq, Debug)]
 pub enum PolicyError {
@@ -326,7 +228,7 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
         let key_prob_map: BTreeMap<_, _> = self
             .to_tapleaf_prob_vec(1.0)
             .into_iter()
-            .filter(|(_, ref pol)| matches!(*pol, Concrete::Key(..)))
+            .filter(|(_, ref pol)| matches!(pol, Concrete::Key(..)))
             .map(|(prob, key)| (key, prob))
             .collect();
@@ -441,16 +343,14 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
         match policy {
             Policy::Trivial => None,
             policy => {
-                let pol = PolicyArc::from(policy);
-                let leaf_compilations: Vec<_> = pol
+                let leaf_compilations: Vec<_> = policy
                     .enumerate_policy_tree(1.0)
                     .into_iter()
-                    .filter(|x| x.1 != Arc::new(PolicyArc::Unsatisfiable))
-                    .map(|(prob, ref pol)| {
-                        let converted_pol = Policy::<Pk>::from((**pol).clone());
+                    .filter(|x| x.1 != Arc::new(Policy::Unsatisfiable))
+                    .map(|(prob, pol)| {
                         (
                             OrdF64(prob),
-                            compiler::best_compilation(&converted_pol).unwrap(),
+                            compiler::best_compilation(pol.as_ref()).unwrap(),
                         )
                     })
                     .collect();
@@ -512,7 +412,7 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
 }
 
 #[cfg(feature = "compiler")]
-impl<Pk: MiniscriptKey> PolicyArc<Pk> {
+impl<Pk: MiniscriptKey> Policy<Pk> {
     /// Returns a vector of policies whose disjunction is isomorphic to the initial one.
     ///
     /// This function is supposed to incrementally expand i.e. represent the policy as
@@ -521,21 +421,19 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
     #[cfg(feature = "compiler")]
     fn enumerate_pol(&self, prob: f64) -> Vec<(f64, Arc<Self>)> {
         match self {
-            PolicyArc::Or(subs) => {
+            Policy::Or(subs) => {
                 let total_odds = subs.iter().fold(0, |acc, x| acc + x.0);
                 subs.iter()
                     .map(|(odds, pol)| (prob * *odds as f64 / total_odds as f64, pol.clone()))
                     .collect::<Vec<_>>()
             }
-            PolicyArc::Threshold(k, subs) if *k == 1 => {
+            Policy::Threshold(k, subs) if *k == 1 => {
                 let total_odds = subs.len();
                 subs.iter()
                     .map(|pol| (prob / total_odds as f64, pol.clone()))
                     .collect::<Vec<_>>()
             }
-            PolicyArc::Threshold(k, subs) if *k != subs.len() => {
-                generate_combination(subs, prob, *k)
-            }
+            Policy::Threshold(k, subs) if *k != subs.len() => generate_combination(subs, prob, *k),
             pol => vec![(prob, Arc::new(pol.clone()))],
         }
     }
@@ -573,7 +471,7 @@ impl<Pk: MiniscriptKey> PolicyArc<Pk> {
         'outer: loop {
             //--- FIND a plausible node ---
             let mut prob: Reverse<OrdF64> = Reverse(OrdF64(0.0));
-            let mut curr_policy: Arc<PolicyArc<Pk>> = Arc::new(PolicyArc::Unsatisfiable);
+            let mut curr_policy: Arc<Policy<Pk>> = Arc::new(Policy::Unsatisfiable);
             let mut curr_pol_replace_vec: Vec<(f64, Arc<Self>)> = vec![];
             let mut no_more_enum = false;
@@ -646,30 +544,14 @@ impl<Pk: MiniscriptKey> PolicyArc<Pk> {
 
 impl<Pk: MiniscriptKey> ForEachKey<Pk> for Policy<Pk> {
     fn for_each_key<'a, F: FnMut(&'a Pk) -> bool>(&'a self, mut pred: F) -> bool {
-        self.real_for_each_key(&mut pred)
+        self.pre_order_iter().all(|policy| match policy {
+            Policy::Key(ref pk) => pred(pk),
+            _ => true,
+        })
     }
 }
 
 impl<Pk: MiniscriptKey> Policy<Pk> {
-    fn real_for_each_key<'a, F: FnMut(&'a Pk) -> bool>(&'a self, pred: &mut F) -> bool {
-        match *self {
-            Policy::Unsatisfiable | Policy::Trivial => true,
-            Policy::Key(ref pk) => pred(pk),
-            Policy::Sha256(..)
-            | Policy::Hash256(..)
-            | Policy::Ripemd160(..)
-            | Policy::Hash160(..)
-            | Policy::After(..)
-            | Policy::Older(..) => true,
-            Policy::Threshold(_, ref subs) | Policy::And(ref subs) => {
-                subs.iter().all(|sub| sub.real_for_each_key(&mut *pred))
-            }
-            Policy::Or(ref subs) => subs
-                .iter()
-                .all(|(_, sub)| sub.real_for_each_key(&mut *pred)),
-        }
-    }
-
     /// Converts a policy using one kind of public key to another type of public key.
     ///
     /// For example usage please see [`crate::policy::semantic::Policy::translate_pk`].
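[Editor's note — not part of the patch] The rewritten `for_each_key` above walks the whole policy with `pre_order_iter` (made possible by the new `TreeLike` impls in src/iter/mod.rs) and applies the predicate to `Key` leaves only, so caller-visible behaviour is unchanged. A minimal usage sketch, assuming `ForEachKey` is exported at the crate root and using the `String` key type that the unit tests use:

```rust
use std::str::FromStr;

use miniscript::policy::Concrete;
use miniscript::ForEachKey;

fn main() {
    let pol = Concrete::<String>::from_str("thresh(2,pk(alice),pk(bob),pk(carol))").unwrap();

    // `for_each_key` returns true only if the predicate holds for every key in the policy.
    assert!(pol.for_each_key(|k| !k.is_empty()));
    // It returns false as soon as one key fails the predicate ("bob" here).
    assert!(!pol.for_each_key(|k| k == "alice"));
}
```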
@@ -678,94 +560,99 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
         T: Translator<Pk, Q, E>,
         Q: MiniscriptKey,
     {
-        self._translate_pk(t)
-    }
-
-    fn _translate_pk<Q, E, T>(&self, t: &mut T) -> Result<Policy<Q>, E>
-    where
-        T: Translator<Pk, Q, E>,
-        Q: MiniscriptKey,
-    {
-        match *self {
-            Policy::Unsatisfiable => Ok(Policy::Unsatisfiable),
-            Policy::Trivial => Ok(Policy::Trivial),
-            Policy::Key(ref pk) => t.pk(pk).map(Policy::Key),
-            Policy::Sha256(ref h) => t.sha256(h).map(Policy::Sha256),
-            Policy::Hash256(ref h) => t.hash256(h).map(Policy::Hash256),
-            Policy::Ripemd160(ref h) => t.ripemd160(h).map(Policy::Ripemd160),
-            Policy::Hash160(ref h) => t.hash160(h).map(Policy::Hash160),
-            Policy::Older(n) => Ok(Policy::Older(n)),
-            Policy::After(n) => Ok(Policy::After(n)),
-            Policy::Threshold(k, ref subs) => {
-                let new_subs: Result<Vec<Policy<Q>>, _> =
-                    subs.iter().map(|sub| sub._translate_pk(t)).collect();
-                new_subs.map(|ok| Policy::Threshold(k, ok))
-            }
-            Policy::And(ref subs) => Ok(Policy::And(
-                subs.iter()
-                    .map(|sub| sub._translate_pk(t))
-                    .collect::<Result<Vec<Policy<Q>>, E>>()?,
-            )),
-            Policy::Or(ref subs) => Ok(Policy::Or(
-                subs.iter()
-                    .map(|(prob, sub)| Ok((*prob, sub._translate_pk(t)?)))
-                    .collect::<Result<Vec<(usize, Policy<Q>)>, E>>()?,
-            )),
+        use Policy::*;
+
+        let mut translated = vec![];
+        for data in self.post_order_iter() {
+            let child_n = |n| Arc::clone(&translated[data.child_indices[n]]);
+
+            let new_policy = match data.node {
+                Unsatisfiable => Unsatisfiable,
+                Trivial => Trivial,
+                Key(ref pk) => t.pk(pk).map(Key)?,
+                Sha256(ref h) => t.sha256(h).map(Sha256)?,
+                Hash256(ref h) => t.hash256(h).map(Hash256)?,
+                Ripemd160(ref h) => t.ripemd160(h).map(Ripemd160)?,
+                Hash160(ref h) => t.hash160(h).map(Hash160)?,
+                Older(ref n) => Older(*n),
+                After(ref n) => After(*n),
+                Threshold(ref k, ref subs) => Threshold(*k, (0..subs.len()).map(child_n).collect()),
+                And(ref subs) => And((0..subs.len()).map(child_n).collect()),
+                Or(ref subs) => Or(subs
+                    .iter()
+                    .enumerate()
+                    .map(|(i, (prob, _))| (*prob, child_n(i)))
+                    .collect()),
+            };
+            translated.push(Arc::new(new_policy));
         }
+        // Unwrap is ok because we know we processed at least one node.
+        let root_node = translated.pop().unwrap();
+        // Unwrap is ok because we know `root_node` is the only strong reference.
+        Ok(Arc::try_unwrap(root_node).unwrap())
     }
 
     /// Translates `Concrete::Key(key)` to `Concrete::Unsatisfiable` when extracting `TapKey`.
     pub fn translate_unsatisfiable_pk(self, key: &Pk) -> Policy<Pk> {
-        match self {
-            Policy::Key(ref k) if k.clone() == *key => Policy::Unsatisfiable,
-            Policy::And(subs) => Policy::And(
-                subs.into_iter()
-                    .map(|sub| sub.translate_unsatisfiable_pk(key))
-                    .collect::<Vec<_>>(),
-            ),
-            Policy::Or(subs) => Policy::Or(
-                subs.into_iter()
-                    .map(|(k, sub)| (k, sub.translate_unsatisfiable_pk(key)))
-                    .collect::<Vec<_>>(),
-            ),
-            Policy::Threshold(k, subs) => Policy::Threshold(
-                k,
-                subs.into_iter()
-                    .map(|sub| sub.translate_unsatisfiable_pk(key))
-                    .collect::<Vec<_>>(),
-            ),
-            x => x,
+        use Policy::*;
+
+        let mut translated = vec![];
+        for data in Arc::new(self).post_order_iter() {
+            let child_n = |n| Arc::clone(&translated[data.child_indices[n]]);
+
+            let new_policy = match data.node.as_ref() {
+                Policy::Key(ref k) if k.clone() == *key => Some(Policy::Unsatisfiable),
+                Threshold(k, ref subs) => {
+                    Some(Threshold(*k, (0..subs.len()).map(child_n).collect()))
+                }
+                And(ref subs) => Some(And((0..subs.len()).map(child_n).collect())),
+                Or(ref subs) => Some(Or(subs
+                    .iter()
+                    .enumerate()
+                    .map(|(i, (prob, _))| (*prob, child_n(i)))
+                    .collect())),
+                _ => None,
+            };
+            match new_policy {
+                Some(new_policy) => translated.push(Arc::new(new_policy)),
+                None => translated.push(Arc::clone(&data.node)),
+            }
         }
+        // Ok to unwrap because we know we processed at least one node.
+        let root_node = translated.pop().unwrap();
+        // Ok to unwrap because we know `root_node` is the only strong reference.
+        Arc::try_unwrap(root_node).unwrap()
     }
 
     /// Gets all keys in the policy.
     pub fn keys(&self) -> Vec<&Pk> {
-        match *self {
-            Policy::Key(ref pk) => vec![pk],
-            Policy::Threshold(_k, ref subs) => {
-                subs.iter().flat_map(|sub| sub.keys()).collect::<Vec<_>>()
-            }
-            Policy::And(ref subs) => subs.iter().flat_map(|sub| sub.keys()).collect::<Vec<_>>(),
-            Policy::Or(ref subs) => subs
-                .iter()
-                .flat_map(|(ref _k, ref sub)| sub.keys())
-                .collect::<Vec<_>>(),
-            // map all hashes and time
-            _ => vec![],
-        }
+        self.pre_order_iter()
+            .filter_map(|policy| match policy {
+                Policy::Key(ref pk) => Some(pk),
+                _ => None,
+            })
+            .collect()
     }
 
     /// Gets the number of [TapLeaf](`TapTree::Leaf`)s considering exhaustive root-level [`Policy::Or`]
     /// and [`Policy::Threshold`] disjunctions for the `TapTree`.
     #[cfg(feature = "compiler")]
     fn num_tap_leaves(&self) -> usize {
-        match self {
-            Policy::Or(subs) => subs.iter().map(|(_prob, pol)| pol.num_tap_leaves()).sum(),
-            Policy::Threshold(k, subs) if *k == 1 => {
-                subs.iter().map(|pol| pol.num_tap_leaves()).sum()
-            }
-            _ => 1,
+        use Policy::*;
+
+        let mut nums = vec![];
+        for data in Arc::new(self).post_order_iter() {
+            let num_for_child_n = |n| nums[data.child_indices[n]];
+
+            let num = match data.node {
+                Or(subs) => (0..subs.len()).map(num_for_child_n).sum(),
+                Threshold(k, subs) if *k == 1 => (0..subs.len()).map(num_for_child_n).sum(),
+                _ => 1,
+            };
+            nums.push(num);
         }
+        // Ok to unwrap because we know we processed at least one node.
+        nums.pop().unwrap()
     }
 
     /// Does checks on the number of `TapLeaf`s.
@@ -798,53 +685,60 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
     /// Returns an error if there is at least one satisfaction that contains
     /// a combination of heightlock and timelock.
     pub fn check_timelocks(&self) -> Result<(), PolicyError> {
-        let timelocks = self.check_timelocks_helper();
-        if timelocks.contains_combination {
+        let aggregated_timelock_info = self.timelock_info();
+        if aggregated_timelock_info.contains_combination {
             Err(PolicyError::HeightTimelockCombination)
         } else {
             Ok(())
         }
     }
 
-    // Checks whether the given concrete policy contains a combination of
-    // timelocks and heightlocks
-    fn check_timelocks_helper(&self) -> TimelockInfo {
-        // timelocks[csv_h, csv_t, cltv_h, cltv_t, combination]
-        match *self {
-            Policy::Unsatisfiable
-            | Policy::Trivial
-            | Policy::Key(_)
-            | Policy::Sha256(_)
-            | Policy::Hash256(_)
-            | Policy::Ripemd160(_)
-            | Policy::Hash160(_) => TimelockInfo::default(),
-            Policy::After(t) => TimelockInfo {
-                csv_with_height: false,
-                csv_with_time: false,
-                cltv_with_height: absolute::LockTime::from(t).is_block_height(),
-                cltv_with_time: absolute::LockTime::from(t).is_block_time(),
-                contains_combination: false,
-            },
-            Policy::Older(t) => TimelockInfo {
-                csv_with_height: t.is_height_locked(),
-                csv_with_time: t.is_time_locked(),
-                cltv_with_height: false,
-                cltv_with_time: false,
-                contains_combination: false,
-            },
-            Policy::Threshold(k, ref subs) => {
-                let iter = subs.iter().map(|sub| sub.check_timelocks_helper());
-                TimelockInfo::combine_threshold(k, iter)
-            }
-            Policy::And(ref subs) => {
-                let iter = subs.iter().map(|sub| sub.check_timelocks_helper());
-                TimelockInfo::combine_threshold(subs.len(), iter)
-            }
-            Policy::Or(ref subs) => {
-                let iter = subs.iter().map(|(_p, sub)| sub.check_timelocks_helper());
-                TimelockInfo::combine_threshold(1, iter)
-            }
+    /// Processes `Policy` using `post_order_iter`, creates a `TimelockInfo` for each `Nullary` node
+    /// and combines them together for `Nary` nodes.
+    ///
+    /// # Returns
+    ///
+    /// A single `TimelockInfo` that is the combination of all others after processing each node.
+    fn timelock_info(&self) -> TimelockInfo {
+        use Policy::*;
+
+        let mut infos = vec![];
+        for data in Arc::new(self).post_order_iter() {
+            let info_for_child_n = |n| infos[data.child_indices[n]];
+
+            let info = match data.node {
+                Policy::After(ref t) => TimelockInfo {
+                    csv_with_height: false,
+                    csv_with_time: false,
+                    cltv_with_height: absolute::LockTime::from(*t).is_block_height(),
+                    cltv_with_time: absolute::LockTime::from(*t).is_block_time(),
+                    contains_combination: false,
+                },
+                Policy::Older(ref t) => TimelockInfo {
+                    csv_with_height: t.is_height_locked(),
+                    csv_with_time: t.is_time_locked(),
+                    cltv_with_height: false,
+                    cltv_with_time: false,
+                    contains_combination: false,
+                },
+                Threshold(ref k, subs) => {
+                    let iter = (0..subs.len()).map(info_for_child_n);
+                    TimelockInfo::combine_threshold(*k, iter)
+                }
+                And(ref subs) => {
+                    let iter = (0..subs.len()).map(info_for_child_n);
+                    TimelockInfo::combine_threshold(subs.len(), iter)
+                }
+                Or(ref subs) => {
+                    let iter = (0..subs.len()).map(info_for_child_n);
+                    TimelockInfo::combine_threshold(1, iter)
+                }
+                _ => TimelockInfo::default(),
+            };
+            infos.push(info);
         }
+        // Ok to unwrap, we had to have visited at least one node.
+        infos.pop().unwrap()
     }
 
     /// This returns whether the given policy is valid or not. It maybe possible that the policy
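[Editor's note — not part of the patch] `timelock_info`, like every other rewrite in this patch, replaces recursion with a fold over `post_order_iter`: children are visited before their parent, each node pushes its result, and a parent reads its children's results back through `child_indices`. A stand-alone sketch of that pattern, using a hypothetical `depth` helper (the iterator item fields follow the usage shown in the diff; they are an assumption of this sketch):

```rust
use std::str::FromStr;

use miniscript::iter::TreeLike;
use miniscript::policy::Concrete;

// Hypothetical helper: height of the policy tree, computed without recursion.
fn depth(pol: &Concrete<String>) -> usize {
    let mut depths: Vec<usize> = vec![];
    for data in pol.post_order_iter() {
        // Children were already pushed, so indexing into `depths` is safe.
        let child_depth = |n: usize| depths[data.child_indices[n]];
        let d = match data.child_indices.len() {
            0 => 1,
            n => 1 + (0..n).map(child_depth).max().unwrap(),
        };
        depths.push(d);
    }
    // At least the root node was processed.
    depths.pop().unwrap()
}

fn main() {
    let pol = Concrete::<String>::from_str("or(and(pk(A),pk(B)),pk(C))").unwrap();
    assert_eq!(depth(&pol), 3); // or -> and -> pk
}
```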
@@ -1127,7 +1021,7 @@ impl_block_str!(
             }
             let mut subs = Vec::with_capacity(top.args.len());
             for arg in &top.args {
-                subs.push(Policy::from_tree(arg)?);
+                subs.push(Arc::new(Policy::from_tree(arg)?));
             }
             Ok(Policy::And(subs))
         }
@@ -1139,7 +1033,7 @@ impl_block_str!(
             for arg in &top.args {
                 subs.push(Policy::from_tree_prob(arg, true)?);
             }
-            Ok(Policy::Or(subs))
+            Ok(Policy::Or(subs.into_iter().map(|(prob, sub)| (prob, Arc::new(sub))).collect()))
         }
         ("thresh", nsubs) => {
             if top.args.is_empty() || !top.args[0].args.is_empty() {
@@ -1155,7 +1049,7 @@ impl_block_str!(
             for arg in &top.args[1..] {
                 subs.push(Policy::from_tree(arg)?);
             }
-            Ok(Policy::Threshold(thresh as usize, subs))
+            Ok(Policy::Threshold(thresh as usize, subs.into_iter().map(Arc::new).collect()))
         }
         _ => Err(errstr(top.name)),
     }
@@ -1207,20 +1101,20 @@ fn with_huffman_tree<Pk: MiniscriptKey>(
 /// any one of the conditions exclusively.
 #[cfg(feature = "compiler")]
 fn generate_combination<Pk: MiniscriptKey>(
-    policy_vec: &Vec<Arc<PolicyArc<Pk>>>,
+    policy_vec: &Vec<Arc<Policy<Pk>>>,
     prob: f64,
     k: usize,
-) -> Vec<(f64, Arc<PolicyArc<Pk>>)> {
+) -> Vec<(f64, Arc<Policy<Pk>>)> {
     debug_assert!(k <= policy_vec.len());
 
-    let mut ret: Vec<(f64, Arc<PolicyArc<Pk>>)> = vec![];
+    let mut ret: Vec<(f64, Arc<Policy<Pk>>)> = vec![];
     for i in 0..policy_vec.len() {
-        let policies: Vec<Arc<PolicyArc<Pk>>> = policy_vec
+        let policies: Vec<Arc<Policy<Pk>>> = policy_vec
            .iter()
            .enumerate()
            .filter_map(|(j, sub)| if j != i { Some(Arc::clone(sub)) } else { None })
            .collect();
-        ret.push((prob / policy_vec.len() as f64, Arc::new(PolicyArc::Threshold(k, policies))));
+        ret.push((prob / policy_vec.len() as f64, Arc::new(Policy::Threshold(k, policies))));
     }
     ret
 }
@@ -1231,58 +1125,49 @@ mod compiler_tests {
 
     use sync::Arc;
 
-    use super::Concrete;
-    use crate::policy::concrete::{generate_combination, PolicyArc};
-    use crate::prelude::*;
+    use super::*;
 
     #[test]
     fn test_gen_comb() {
-        let policies: Vec<Concrete<String>> = vec!["pk(A)", "pk(B)", "pk(C)", "pk(D)"]
+        let policies: Vec<Arc<Concrete<String>>> = vec!["pk(A)", "pk(B)", "pk(C)", "pk(D)"]
             .into_iter()
             .map(|st| policy_str!("{}", st))
+            .map(|p| Arc::new(p))
             .collect();
-        let policy_vec = policies
-            .into_iter()
-            .map(|pol| Arc::new(PolicyArc::from(pol)))
-            .collect::<Vec<_>>();
-        let combinations = generate_combination(&policy_vec, 1.0, 2);
+        let combinations = generate_combination(&policies, 1.0, 2);
 
-        let comb_a: Vec<Arc<PolicyArc<String>>> = vec![
+        let comb_a: Vec<Concrete<String>> = vec![
             policy_str!("pk(B)"),
             policy_str!("pk(C)"),
             policy_str!("pk(D)"),
-        ]
-        .into_iter()
-        .map(|pol| Arc::new(PolicyArc::from(pol)))
-        .collect();
-        let comb_b: Vec<Arc<PolicyArc<String>>> = vec![
+        ];
+        let comb_b: Vec<Concrete<String>> = vec![
             policy_str!("pk(A)"),
             policy_str!("pk(C)"),
             policy_str!("pk(D)"),
-        ]
-        .into_iter()
-        .map(|pol| Arc::new(PolicyArc::from(pol)))
-        .collect();
-        let comb_c: Vec<Arc<PolicyArc<String>>> = vec![
+        ];
+        let comb_c: Vec<Concrete<String>> = vec![
             policy_str!("pk(A)"),
             policy_str!("pk(B)"),
             policy_str!("pk(D)"),
-        ]
-        .into_iter()
-        .map(|pol| Arc::new(PolicyArc::from(pol)))
-        .collect();
-        let comb_d: Vec<Arc<PolicyArc<String>>> = vec![
+        ];
+        let comb_d: Vec<Concrete<String>> = vec![
             policy_str!("pk(A)"),
             policy_str!("pk(B)"),
             policy_str!("pk(C)"),
-        ]
-        .into_iter()
-        .map(|pol| Arc::new(PolicyArc::from(pol)))
-        .collect();
+        ];
 
         let expected_comb = vec![comb_a, comb_b, comb_c, comb_d]
             .into_iter()
-            .map(|sub_pol| (0.25, Arc::new(PolicyArc::Threshold(2, sub_pol))))
+            .map(|sub_pol| {
+                (
+                    0.25,
+                    Arc::new(Policy::Threshold(
+                        2,
+                        sub_pol.into_iter().map(|p| Arc::new(p)).collect(),
+                    )),
+                )
+            })
             .collect::<Vec<_>>();
         assert_eq!(combinations, expected_comb);
     }
@@ -1295,7 +1180,7 @@ mod tests {
     use super::*;
 
     #[test]
-    fn for_each_key() {
+    fn for_each_key_count_keys() {
         let liquid_pol = Policy::<String>::from_str(
             "or(and(older(4096),thresh(2,pk(A),pk(B),pk(C))),thresh(11,pk(F1),pk(F2),pk(F3),pk(F4),pk(F5),pk(F6),pk(F7),pk(F8),pk(F9),pk(F10),pk(F11),pk(F12),pk(F13),pk(F14)))").unwrap();
 
         let mut count = 0;
@@ -1305,4 +1190,69 @@
         }));
         assert_eq!(count, 17);
     }
+
+    #[test]
+    fn for_each_key_fails_predicate() {
+        let policy =
+            Policy::<String>::from_str("or(and(pk(key0),pk(key1)),pk(oddnamedkey))").unwrap();
+        assert!(!policy.for_each_key(|k| k.starts_with("key")));
+    }
+
+    #[test]
+    fn tranaslate_pk() {
+        pub struct TestTranslator;
+        impl Translator<String, String, ()> for TestTranslator {
+            fn pk(&mut self, pk: &String) -> Result<String, ()> {
+                let new = format!("NEW-{}", pk);
+                Ok(new.to_string())
+            }
+            fn sha256(&mut self, hash: &String) -> Result<String, ()> { Ok(hash.to_string()) }
+            fn hash256(&mut self, hash: &String) -> Result<String, ()> { Ok(hash.to_string()) }
+            fn ripemd160(&mut self, hash: &String) -> Result<String, ()> { Ok(hash.to_string()) }
+            fn hash160(&mut self, hash: &String) -> Result<String, ()> { Ok(hash.to_string()) }
+        }
+        let policy = Policy::<String>::from_str("or(and(pk(A),pk(B)),pk(C))").unwrap();
+        let mut t = TestTranslator;
+
+        let want = Policy::<String>::from_str("or(and(pk(NEW-A),pk(NEW-B)),pk(NEW-C))").unwrap();
+        let got = policy
+            .translate_pk(&mut t)
+            .expect("failed to translate keys");
+
+        assert_eq!(got, want);
+    }
+
+    #[test]
+    fn translate_unsatisfiable_pk() {
+        let policy = Policy::<String>::from_str("or(and(pk(A),pk(B)),pk(C))").unwrap();
+
+        let want = Policy::<String>::from_str("or(and(pk(A),UNSATISFIABLE),pk(C))").unwrap();
+        let got = policy.translate_unsatisfiable_pk(&"B".to_string());
+
+        assert_eq!(got, want);
+    }
+
+    #[test]
+    fn keys() {
+        let policy = Policy::<String>::from_str("or(and(pk(A),pk(B)),pk(C))").unwrap();
+
+        let want = vec!["A", "B", "C"];
+        let got = policy.keys();
+
+        assert_eq!(got, want);
+    }
+
+    #[test]
+    #[cfg(feature = "compiler")]
+    fn num_tap_leaves() {
+        let policy = Policy::<String>::from_str("or(and(pk(A),pk(B)),pk(C))").unwrap();
+        assert_eq!(policy.num_tap_leaves(), 2);
+    }
+
+    #[test]
+    #[should_panic]
+    fn check_timelocks() {
+        // This implicitly tests the check_timelocks API (has height and time locks).
+        let _ = Policy::<String>::from_str("and(after(10),after(500000000))").unwrap();
+    }
 }
diff --git a/src/policy/mod.rs b/src/policy/mod.rs
index c2d4fdbf5..138ec45b7 100644
--- a/src/policy/mod.rs
+++ b/src/policy/mod.rs
@@ -22,6 +22,7 @@ pub use self::concrete::Policy as Concrete;
 pub use self::semantic::Policy as Semantic;
 use crate::descriptor::Descriptor;
 use crate::miniscript::{Miniscript, ScriptContext};
+use crate::sync::Arc;
 use crate::{Error, MiniscriptKey, Terminal};
 
 /// Policy entailment algorithm maximum number of terminals allowed.
@@ -213,6 +214,9 @@ impl<Pk: MiniscriptKey> Liftable<Pk> for Concrete<Pk> {
         Ok(ret)
     }
 }
+impl<Pk: MiniscriptKey> Liftable<Pk> for Arc<Concrete<Pk>> {
+    fn lift(&self) -> Result<Semantic<Pk>, Error> { self.as_ref().lift() }
+}
 
 #[cfg(test)]
 mod tests {
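[Editor's note — not part of the patch] The new `Liftable` impl at the end simply forwards through the `Arc`, so the `Arc<Concrete<Pk>>` values produced by the iterator-based rewrites can be lifted to a semantic policy directly. A minimal sketch, assuming `Concrete` and `Liftable` are exported from the public `policy` module:

```rust
use std::str::FromStr;
use std::sync::Arc;

use miniscript::policy::{Concrete, Liftable};

fn main() {
    let pol = Arc::new(Concrete::<String>::from_str("and(pk(A),pk(B))").unwrap());

    // Lifting through the Arc gives the same semantic policy as lifting the inner value.
    assert_eq!(pol.lift().unwrap(), pol.as_ref().lift().unwrap());
}
```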