diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs new file mode 100644 index 0000000000..fb33191558 --- /dev/null +++ b/hugr-passes/src/composable.rs @@ -0,0 +1,361 @@ +//! Compiler passes and utilities for composing them + +use std::{error::Error, marker::PhantomData}; + +use hugr_core::hugr::{hugrmut::HugrMut, ValidationError}; +use hugr_core::HugrView; +use itertools::Either; + +/// An optimization pass that can be sequenced with another and/or wrapped +/// e.g. by [ValidatingPass] +pub trait ComposablePass: Sized { + type Error: Error; + type Result; // Would like to default to () but currently unstable + + fn run(&self, hugr: &mut impl HugrMut) -> Result; + + fn map_err( + self, + f: impl Fn(Self::Error) -> E2, + ) -> impl ComposablePass { + ErrMapper::new(self, f) + } + + /// Returns a [ComposablePass] that does "`self` then `other`", so long as + /// `other::Err` can be combined with ours. + fn then>( + self, + other: P, + ) -> impl ComposablePass { + struct Sequence(P1, P2, PhantomData); + impl ComposablePass for Sequence + where + P1: ComposablePass, + P2: ComposablePass, + E: ErrorCombiner, + { + type Error = E; + + type Result = (P1::Result, P2::Result); + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let res1 = self.0.run(hugr).map_err(E::from_first)?; + let res2 = self.1.run(hugr).map_err(E::from_second)?; + Ok((res1, res2)) + } + } + + Sequence(self, other, PhantomData) + } +} + +/// Trait for combining the error types from two different passes +/// into a single error. +pub trait ErrorCombiner: Error { + fn from_first(a: A) -> Self; + fn from_second(b: B) -> Self; +} + +impl> ErrorCombiner for A { + fn from_first(a: A) -> Self { + a + } + + fn from_second(b: B) -> Self { + b.into() + } +} + +impl ErrorCombiner for Either { + fn from_first(a: A) -> Self { + Either::Left(a) + } + + fn from_second(b: B) -> Self { + Either::Right(b) + } +} + +// Note: in the short term we could wish for two more impls: +// impl ErrorCombiner for E +// impl ErrorCombiner for E +// however, these aren't possible as they conflict with +// impl> ErrorCombiner for A +// when A=E=Infallible, boo :-(. +// However this will become possible, indeed automatic, when Infallible is replaced +// by ! (never_type) as (unlike Infallible) ! converts Into anything + +// ErrMapper ------------------------------ +struct ErrMapper(P, F, PhantomData); + +impl E> ErrMapper { + fn new(pass: P, err_fn: F) -> Self { + Self(pass, err_fn, PhantomData) + } +} + +impl E> ComposablePass for ErrMapper { + type Error = E; + type Result = P::Result; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.0.run(hugr).map_err(&self.1) + } +} + +// ValidatingPass ------------------------------ + +/// Error from a [ValidatingPass] +#[derive(thiserror::Error, Debug)] +pub enum ValidatePassError { + #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] + Input { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")] + Output { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error(transparent)] + Underlying(#[from] E), +} + +/// Runs an underlying pass, but with validation of the Hugr +/// both before and afterwards. +pub struct ValidatingPass

(P, bool); + +impl ValidatingPass

{ + pub fn new_default(underlying: P) -> Self { + // Self(underlying, cfg!(feature = "extension_inference")) + // Sadly, many tests fail with extension inference, hence: + Self(underlying, false) + } + + pub fn new_validating_extensions(underlying: P) -> Self { + Self(underlying, true) + } + + pub fn new(underlying: P, validate_extensions: bool) -> Self { + Self(underlying, validate_extensions) + } + + fn validation_impl( + &self, + hugr: &impl HugrView, + mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError, + ) -> Result<(), ValidatePassError> { + match self.1 { + false => hugr.validate_no_extensions(), + true => hugr.validate(), + } + .map_err(|err| mk_err(err, hugr.mermaid_string())) + } +} + +impl ComposablePass for ValidatingPass

{ + type Error = ValidatePassError; + type Result = P::Result; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input { + err, + pretty_hugr, + })?; + let res = self.0.run(hugr).map_err(ValidatePassError::Underlying)?; + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output { + err, + pretty_hugr, + })?; + Ok(res) + } +} + +// IfThen ------------------------------ +/// [ComposablePass] that executes a first pass that returns a `bool` +/// result; and then, if-and-only-if that first result was true, +/// executes a second pass +pub struct IfThen(A, B, PhantomData); + +impl, B: ComposablePass, E: ErrorCombiner> + IfThen +{ + /// Make a new instance given the [ComposablePass] to run first + /// and (maybe) second + pub fn new(fst: A, opt_snd: B) -> Self { + Self(fst, opt_snd, PhantomData) + } +} + +impl, B: ComposablePass, E: ErrorCombiner> + ComposablePass for IfThen +{ + type Error = E; + + type Result = Option; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?; + res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second)) + .transpose() + } +} + +pub(crate) fn validate_if_test( + pass: P, + hugr: &mut impl HugrMut, +) -> Result> { + if cfg!(test) { + ValidatingPass::new_default(pass).run(hugr) + } else { + pass.run(hugr).map_err(ValidatePassError::Underlying) + } +} + +#[cfg(test)] +mod test { + use itertools::{Either, Itertools}; + use std::convert::Infallible; + + use hugr_core::builder::{ + Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + ModuleBuilder, + }; + use hugr_core::extension::prelude::{ + bool_t, usize_t, ConstUsize, MakeTuple, UnpackTuple, PRELUDE_ID, + }; + use hugr_core::hugr::hugrmut::HugrMut; + use hugr_core::ops::{handle::NodeHandle, Input, OpType, Output, DEFAULT_OPTYPE, DFG}; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; + use hugr_core::types::{Signature, TypeRow}; + use hugr_core::{Hugr, HugrView, IncomingPort}; + + use crate::const_fold::{ConstFoldError, ConstantFoldPass}; + use crate::untuple::{UntupleRecursive, UntupleResult}; + use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass}; + + use super::{validate_if_test, ComposablePass, IfThen, ValidatePassError, ValidatingPass}; + + #[test] + fn test_then() { + let mut mb = ModuleBuilder::new(); + let id1 = mb + .define_function("id1", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id1.input_wires(); + let id1 = id1.finish_with_outputs(inps).unwrap(); + let id2 = mb + .define_function("id2", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id2.input_wires(); + let id2 = id2.finish_with_outputs(inps).unwrap(); + let hugr = mb.finish_hugr().unwrap(); + + let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]); + let cfold = + ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]); + + cfold.run(&mut hugr.clone()).unwrap(); + + let exp_err = ConstFoldError::InvalidEntryPoint(id2.node(), DEFAULT_OPTYPE); + let r: Result<_, Either> = + dce.clone().then(cfold.clone()).run(&mut hugr.clone()); + assert_eq!(r, Err(Either::Right(exp_err.clone()))); + + let r = dce + .clone() + .map_err(|inf| match inf {}) + .then(cfold.clone()) + .run(&mut hugr.clone()); + assert_eq!(r, Err(exp_err)); + + let r2: Result<_, Either<_, _>> = cfold.then(dce).run(&mut hugr.clone()); + r2.unwrap(); + } + + #[test] + fn test_validation() { + let mut h = Hugr::new(DFG { + signature: Signature::new(usize_t(), bool_t()), + }); + let inp = h.add_node_with_parent( + h.root(), + Input { + types: usize_t().into(), + }, + ); + let outp = h.add_node_with_parent( + h.root(), + Output { + types: bool_t().into(), + }, + ); + h.connect(inp, 0, outp, 0); + let backup = h.clone(); + let err = backup.validate().unwrap_err(); + + let no_inputs: [(IncomingPort, _); 0] = []; + let cfold = ConstantFoldPass::default().with_inputs(backup.root(), no_inputs); + cfold.run(&mut h).unwrap(); + assert_eq!(h, backup); // Did nothing + + let r = ValidatingPass(cfold, false).run(&mut h); + assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err)); + } + + #[test] + fn test_if_then() { + let tr = TypeRow::from(vec![usize_t(); 2]); + + let h = { + let sig = Signature::new_endo(tr.clone()).with_extension_delta(PRELUDE_ID); + let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap(); + let [a, b] = fb.input_wires_arr(); + let tup = fb + .add_dataflow_op(MakeTuple::new(tr.clone()), [a, b]) + .unwrap(); + let untup = fb + .add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs()) + .unwrap(); + fb.finish_hugr_with_outputs(untup.outputs()).unwrap() + }; + + let untup = UntuplePass::new(UntupleRecursive::Recursive); + { + // Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple + let mut repl = ReplaceTypes::default(); + let usize_custom_t = usize_t().as_extension().unwrap().clone(); + repl.replace_type(usize_custom_t, INT_TYPES[6].clone()); + let ifthen = IfThen::, _, _>::new(repl, untup.clone()); + + let mut h = h.clone(); + let r = validate_if_test(ifthen, &mut h).unwrap(); + assert_eq!( + r, + Some(UntupleResult { + rewrites_applied: 1 + }) + ); + let [tuple_in, tuple_out] = h.children(h.root()).collect_array().unwrap(); + assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]); + } + + // Change INT_TYPES[5] to INT_TYPES[6]; that won't do anything, so don't Untuple + let mut repl = ReplaceTypes::default(); + let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone(); + repl.replace_type(i32_custom_t, INT_TYPES[6].clone()); + let ifthen = IfThen::, _, _>::new(repl, untup); + let mut h = h; + let r = validate_if_test(ifthen, &mut h).unwrap(); + assert_eq!(r, None); + assert_eq!(h.children(h.root()).count(), 4); + let mktup = h + .output_neighbours(h.first_child(h.root()).unwrap()) + .next() + .unwrap(); + assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr))); + } +} diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index e73e3cd0eb..99ccc180c7 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -21,12 +21,11 @@ use crate::dataflow::{ TailLoopTermination, }; use crate::dead_code::{DeadCodeElimPass, PreserveNode}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{composable::validate_if_test, ComposablePass}; #[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. pub struct ConstantFoldPass { - validation: ValidationLevel, allow_increase_termination: bool, /// Each outer key Node must be either: /// - a FuncDefn child of the root, if the root is a module; or @@ -34,13 +33,10 @@ pub struct ConstantFoldPass { inputs: HashMap>, } -#[derive(Debug, Error)] +#[derive(Clone, Debug, Error, PartialEq)] #[non_exhaustive] /// Errors produced by [ConstantFoldPass]. pub enum ConstFoldError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), /// Error raised when a Node is specified as an entry-point but /// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor /// a [Conditional](OpType::Conditional). @@ -49,12 +45,6 @@ pub enum ConstFoldError { } impl ConstantFoldPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their /// result (if/when they do terminate) is either known or not needed. /// @@ -86,9 +76,19 @@ impl ConstantFoldPass { .extend(inputs.into_iter().map(|(p, v)| (p.into(), v))); self } +} + +impl ComposablePass for ConstantFoldPass { + type Error = ConstFoldError; + type Result = (); /// Run the Constant Folding pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { + /// + /// # Errors + /// + /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] + /// was of an invalid [OpType] + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { let fresh_node = Node::from(portgraph::NodeIndex::new( hugr.nodes().max().map_or(0, |n| n.index() + 1), )); @@ -164,23 +164,10 @@ impl ConstantFoldPass { } }) }) - .run(hugr)?; + .run(hugr) + .map_err(|inf| match inf {})?; // TODO use into_ok when available Ok(()) } - - /// Run the pass using this configuration. - /// - /// # Errors - /// - /// [ConstFoldError::ValidationError] if the Hugr does not validate before/afnerwards - /// (if [Self::validation_level] is set, or in tests) - /// - /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] - /// was of an invalid OpType - pub fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } /// Exhaustively apply constant folding to a HUGR. @@ -198,7 +185,7 @@ pub fn constant_fold_pass(h: &mut H) { } else { c }; - c.run(h).unwrap() + validate_if_test(c, h).unwrap() } struct ConstFoldContext; diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 58e69c568c..ff5cd93a5c 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -32,6 +32,7 @@ use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV}; use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; use crate::dataflow::{partial_from_const, DFContext, PartialValue}; +use crate::ComposablePass as _; use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle}; diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index b714dd6fd1..899e30243a 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -1,13 +1,14 @@ //! Pass for removing dead code, i.e. that computes values that are then discarded use hugr_core::{hugr::hugrmut::HugrMut, ops::OpType, Hugr, HugrView, Node}; +use std::convert::Infallible; use std::fmt::{Debug, Formatter}; use std::{ collections::{HashMap, HashSet, VecDeque}, sync::Arc, }; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration for Dead Code Elimination pass #[derive(Clone)] @@ -18,7 +19,6 @@ pub struct DeadCodeElimPass { /// Callback identifying nodes that must be preserved even if their /// results are not used. Defaults to [PreserveNode::default_for]. preserve_callback: Arc, - validation: ValidationLevel, } impl Default for DeadCodeElimPass { @@ -26,7 +26,6 @@ impl Default for DeadCodeElimPass { Self { entry_points: Default::default(), preserve_callback: Arc::new(PreserveNode::default_for), - validation: ValidationLevel::default(), } } } @@ -39,13 +38,11 @@ impl Debug for DeadCodeElimPass { #[derive(Debug)] struct DCEDebug<'a> { entry_points: &'a Vec, - validation: ValidationLevel, } Debug::fmt( &DCEDebug { entry_points: &self.entry_points, - validation: self.validation, }, f, ) @@ -86,13 +83,6 @@ impl PreserveNode { } impl DeadCodeElimPass { - /// Sets the validation level used before and after the pass is run - #[allow(unused)] - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows setting a callback that determines whether a node must be preserved /// (even when its result is not used) pub fn set_preserve_callback(mut self, cb: Arc) -> Self { @@ -146,24 +136,6 @@ impl DeadCodeElimPass { needed } - pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { - self.validation.run_validated_pass(hugr, |h, _| { - self.run_no_validate(h); - Ok(()) - }) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) { - let needed = self.find_needed_nodes(&*hugr); - let remove = hugr - .nodes() - .filter(|n| !needed.contains(n)) - .collect::>(); - for n in remove { - hugr.remove_node(n); - } - } - fn must_preserve( &self, h: &impl HugrView, @@ -185,6 +157,22 @@ impl DeadCodeElimPass { } } +impl ComposablePass for DeadCodeElimPass { + type Error = Infallible; + type Result = (); + + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Infallible> { + let needed = self.find_needed_nodes(&*hugr); + let remove = hugr + .nodes() + .filter(|n| !needed.contains(n)) + .collect::>(); + for n in remove { + hugr.remove_node(n); + } + Ok(()) + } +} #[cfg(test)] mod test { use std::sync::Arc; @@ -196,6 +184,8 @@ mod test { use hugr_core::{ops::Value, type_row, HugrView}; use itertools::Itertools; + use crate::ComposablePass; + use super::{DeadCodeElimPass, PreserveNode}; #[test] diff --git a/hugr-passes/src/dead_funcs.rs b/hugr-passes/src/dead_funcs.rs index b114a9e428..7071d53358 100644 --- a/hugr-passes/src/dead_funcs.rs +++ b/hugr-passes/src/dead_funcs.rs @@ -10,7 +10,10 @@ use hugr_core::{ }; use petgraph::visit::{Dfs, Walker}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{ + composable::{validate_if_test, ValidatePassError}, + ComposablePass, +}; use super::call_graph::{CallGraph, CallGraphNode}; @@ -26,9 +29,6 @@ pub enum RemoveDeadFuncsError { /// The invalid node. node: N, }, - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), } fn reachable_funcs<'a, H: HugrView>( @@ -64,17 +64,10 @@ fn reachable_funcs<'a, H: HugrView>( #[derive(Debug, Clone, Default)] /// A configuration for the Dead Function Removal pass. pub struct RemoveDeadFuncsPass { - validation: ValidationLevel, entry_points: Vec, } impl RemoveDeadFuncsPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Adds new entry points - these must be [FuncDefn] nodes /// that are children of the [Module] at the root of the Hugr. /// @@ -87,16 +80,32 @@ impl RemoveDeadFuncsPass { self.entry_points.extend(entry_points); self } +} - /// Runs the pass (see [remove_dead_funcs]) with this configuration - pub fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> { - self.validation.run_validated_pass(hugr, |hugr: &mut H, _| { - remove_dead_funcs(hugr, self.entry_points.iter().cloned()) - }) +impl ComposablePass for RemoveDeadFuncsPass { + type Error = RemoveDeadFuncsError; + type Result = (); + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), RemoveDeadFuncsError> { + let reachable = reachable_funcs( + &CallGraph::new(hugr), + hugr, + self.entry_points.iter().cloned(), + )? + .collect::>(); + let unreachable = hugr + .nodes() + .filter(|n| { + OpTag::Function.is_superset(hugr.get_optype(*n).tag()) && !reachable.contains(n) + }) + .collect::>(); + for n in unreachable { + hugr.remove_subtree(n); + } + Ok(()) } } -/// Delete from the Hugr any functions that are not used by either [Call] or +/// Deletes from the Hugr any functions that are not used by either [Call] or /// [LoadFunction] nodes in reachable parts. /// /// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points, @@ -118,16 +127,11 @@ impl RemoveDeadFuncsPass { pub fn remove_dead_funcs( h: &mut impl HugrMut, entry_points: impl IntoIterator, -) -> Result<(), RemoveDeadFuncsError> { - let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points)?.collect::>(); - let unreachable = h - .nodes() - .filter(|n| OpTag::Function.is_superset(h.get_optype(*n).tag()) && !reachable.contains(n)) - .collect::>(); - for n in unreachable { - h.remove_subtree(n); - } - Ok(()) +) -> Result<(), ValidatePassError> { + validate_if_test( + RemoveDeadFuncsPass::default().with_module_entry_points(entry_points), + h, + ) } #[cfg(test)] @@ -142,7 +146,7 @@ mod test { }; use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView}; - use super::RemoveDeadFuncsPass; + use super::remove_dead_funcs; #[rstest] #[case([], vec![])] // No entry_points removes everything! @@ -182,15 +186,14 @@ mod test { }) .collect::>(); - RemoveDeadFuncsPass::default() - .with_module_entry_points( - entry_points - .into_iter() - .map(|name| *avail_funcs.get(name).unwrap()) - .collect::>(), - ) - .run(&mut hugr) - .unwrap(); + remove_dead_funcs( + &mut hugr, + entry_points + .into_iter() + .map(|name| *avail_funcs.get(name).unwrap()) + .collect::>(), + ) + .unwrap(); let remaining_funcs = hugr .nodes() diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 961c4da471..83ff71b675 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,8 @@ //! Compilation passes acting on the HUGR program representation. pub mod call_graph; +pub mod composable; +pub use composable::ComposablePass; pub mod const_fold; pub mod dataflow; pub mod dead_code; @@ -21,19 +23,11 @@ pub mod untuple; )] #[allow(deprecated)] pub use monomorphize::remove_polyfuncs; -// TODO: Deprecated re-export. Remove on a breaking release. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -#[allow(deprecated)] -pub use monomorphize::monomorphize; -pub use monomorphize::{MonomorphizeError, MonomorphizePass}; +pub use monomorphize::{monomorphize, MonomorphizePass}; pub mod replace_types; pub use replace_types::ReplaceTypes; pub mod nest_cfgs; pub mod non_local; -pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 4f4e9bda2f..875ee93557 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -1,5 +1,6 @@ use std::{ collections::{hash_map::Entry, HashMap}, + convert::Infallible, fmt::Write, ops::Deref, }; @@ -12,7 +13,9 @@ use hugr_core::{ use hugr_core::hugr::{hugrmut::HugrMut, Hugr, HugrView, OpType}; use itertools::Itertools as _; -use thiserror::Error; + +use crate::composable::{validate_if_test, ValidatePassError}; +use crate::ComposablePass; /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. @@ -30,26 +33,8 @@ use thiserror::Error; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -// TODO: Deprecated. Remove on a breaking release and rename private `monomorphize_ref` to `monomorphize`. -pub fn monomorphize(mut h: Hugr) -> Hugr { - monomorphize_ref(&mut h); - h -} - -fn monomorphize_ref(h: &mut impl HugrMut) { - let root = h.root(); - // If the root is a polymorphic function, then there are no external calls, so nothing to do - if !is_polymorphic_funcdefn(h.get_optype(root)) { - mono_scan(h, root, None, &mut HashMap::new()); - if !h.get_optype(root).is_module() { - #[allow(deprecated)] // TODO remove in next breaking release and update docs - remove_polyfuncs_ref(h); - } - } +pub fn monomorphize(hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { + validate_if_test(MonomorphizePass, hugr) } /// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have @@ -254,8 +239,6 @@ fn instantiate( mono_tgt } -use crate::validation::{ValidatePassError, ValidationLevel}; - /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. /// @@ -271,38 +254,25 @@ use crate::validation::{ValidatePassError, ValidationLevel}; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[derive(Debug, Clone, Default)] -pub struct MonomorphizePass { - validation: ValidationLevel, -} - -#[derive(Debug, Error)] -#[non_exhaustive] -/// Errors produced by [MonomorphizePass]. -pub enum MonomorphizeError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), -} - -impl MonomorphizePass { - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - - /// Run the Monomorphization pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), MonomorphizeError> { - monomorphize_ref(hugr); +#[derive(Debug, Clone)] +pub struct MonomorphizePass; + +impl ComposablePass for MonomorphizePass { + type Error = Infallible; + type Result = (); + + fn run(&self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + let root = h.root(); + // If the root is a polymorphic function, then there are no external calls, so nothing to do + if !is_polymorphic_funcdefn(h.get_optype(root)) { + mono_scan(h, root, None, &mut HashMap::new()); + if !h.get_optype(root).is_module() { + #[allow(deprecated)] // TODO remove in next breaking release and update docs + remove_polyfuncs_ref(h); + } + } Ok(()) } - - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result<(), MonomorphizeError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } struct TypeArgsList<'a>(&'a [TypeArg]); @@ -387,9 +357,9 @@ mod test { use hugr_core::{Hugr, HugrView, Node}; use rstest::rstest; - use crate::remove_dead_funcs; + use crate::{monomorphize, remove_dead_funcs}; - use super::{is_polymorphic, mangle_inner_func, mangle_name, MonomorphizePass}; + use super::{is_polymorphic, mangle_inner_func, mangle_name}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -410,7 +380,7 @@ mod test { let [i1] = dfg_builder.input_wires_arr(); let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap(); let mut hugr2 = hugr.clone(); - MonomorphizePass::default().run(&mut hugr2).unwrap(); + monomorphize(&mut hugr2).unwrap(); assert_eq!(hugr, hugr2); } @@ -472,7 +442,7 @@ mod test { .count(), 3 ); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono = hugr; mono.validate()?; @@ -493,7 +463,7 @@ mod test { ["double", "main", "triple"] ); let mut mono2 = mono.clone(); - MonomorphizePass::default().run(&mut mono2)?; + monomorphize(&mut mono2)?; assert_eq!(mono2, mono); // Idempotent @@ -601,7 +571,7 @@ mod test { .outputs_arr(); let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); @@ -662,7 +632,7 @@ mod test { let mono = mono.finish_with_outputs([a, b]).unwrap(); let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap(); let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono_hugr = hugr; let mut funcs = list_funcs(&mono_hugr); @@ -719,7 +689,7 @@ mod test { module_builder.finish_hugr().unwrap() }; - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); remove_dead_funcs(&mut hugr, []).unwrap(); let funcs = list_funcs(&hugr); diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 3ed7337a98..e81a640e3f 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -26,7 +26,7 @@ use hugr_core::types::{ }; use hugr_core::{Hugr, HugrView, Node, Wire}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; mod linearize; pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError, Linearizer}; @@ -143,7 +143,6 @@ pub struct ReplaceTypes { ParametricType, Arc Result, ReplaceTypesError>>, >, - validation: ValidationLevel, } impl Default for ReplaceTypes { @@ -184,8 +183,6 @@ pub enum ReplaceTypesError { #[error(transparent)] SignatureError(#[from] SignatureError), #[error(transparent)] - ValidationError(#[from] ValidatePassError), - #[error(transparent)] ConstError(#[from] ConstTypeError), #[error(transparent)] LinearizeError(#[from] LinearizeError), @@ -203,16 +200,9 @@ impl ReplaceTypes { param_ops: Default::default(), consts: Default::default(), param_consts: Default::default(), - validation: Default::default(), } } - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Configures this instance to replace occurrences of type `src` with `dest`. /// Note that if `src` is an instance of a *parametrized* [TypeDef], this takes /// precedence over [Self::replace_parametrized_type] where the `src`s overlap. Thus, this @@ -323,36 +313,6 @@ impl ReplaceTypes { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { - let mut changed = false; - for n in hugr.nodes().collect::>() { - changed |= self.change_node(hugr, n)?; - let new_dfsig = hugr.get_optype(n).dataflow_signature(); - if let Some(new_sig) = new_dfsig - .filter(|_| changed && n != hugr.root()) - .map(Cow::into_owned) - { - for outp in new_sig.output_ports() { - if !new_sig.out_port_type(outp).unwrap().copyable() { - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() != 1 { - hugr.disconnect(n, outp); - let src = Wire::new(n, outp); - self.linearize.insert_copy_discard(hugr, src, &targets)?; - } - } - } - } - } - Ok(changed) - } - fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) @@ -472,11 +432,40 @@ impl ReplaceTypes { false } }), - Value::Function { hugr } => self.run_no_validate(&mut **hugr), + Value::Function { hugr } => self.run(&mut **hugr), } } } +impl ComposablePass for ReplaceTypes { + type Error = ReplaceTypesError; + type Result = bool; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let mut changed = false; + for n in hugr.nodes().collect::>() { + changed |= self.change_node(hugr, n)?; + let new_dfsig = hugr.get_optype(n).dataflow_signature(); + if let Some(new_sig) = new_dfsig + .filter(|_| changed && n != hugr.root()) + .map(Cow::into_owned) + { + for outp in new_sig.output_ports() { + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let src = Wire::new(n, outp); + self.linearize.insert_copy_discard(hugr, src, &targets)?; + } + } + } + } + } + Ok(changed) + } +} + pub mod handlers; #[derive(Clone, Hash, PartialEq, Eq)] @@ -532,29 +521,26 @@ mod test { use hugr_core::extension::prelude::{ bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, }; - use hugr_core::extension::simple_op::MakeExtensionOp; - use hugr_core::extension::{TypeDefBound, Version}; - - use hugr_core::ops::constant::OpaqueValue; - use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; - use hugr_core::std_extensions::arithmetic::int_types::ConstInt; - use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; + use hugr_core::extension::{simple_op::MakeExtensionOp, TypeDefBound, Version}; + use hugr_core::hugr::{IdentList, ValidationError}; + use hugr_core::ops::{ + constant::OpaqueValue, ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value, + }; + use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; + use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use hugr_core::std_extensions::collections::array::{ array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, }; use hugr_core::std_extensions::collections::list::{ list_type, list_type_def, ListOp, ListValue, }; - - use hugr_core::hugr::ValidationError; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; - use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; + use hugr_core::{type_row, Extension, HugrView}; use itertools::Itertools; use rstest::rstest; - use crate::validation::ValidatePassError; + use crate::ComposablePass; - use super::ReplaceTypesError; use super::{handlers::list_const, NodeTemplate, ReplaceTypes}; const PACKED_VEC: &str = "PackedVec"; @@ -979,13 +965,16 @@ mod test { let cu = cst.value().downcast_ref::().unwrap(); Ok(ConstInt::new_u(6, cu.value())?.into()) }); + + let mut h = backup.clone(); + repl.run(&mut h).unwrap(); // No validation here assert!( - matches!(repl.run(&mut backup.clone()), Err(ReplaceTypesError::ValidationError(ValidatePassError::OutputError { - err: ValidationError::IncompatiblePorts {from, to, ..}, .. - })) if backup.get_optype(from).is_const() && to == c.node()) + matches!(h.validate(), Err(ValidationError::IncompatiblePorts {from, to, ..}) + if backup.get_optype(from).is_const() && to == c.node()) ); repl.replace_consts_parametrized(array_type_def(), array_const); let mut h = backup; - repl.run(&mut h).unwrap(); // Includes validation + repl.run(&mut h).unwrap(); + h.validate_no_extensions().unwrap(); } } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 5b4da71846..bc508bd539 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -377,7 +377,7 @@ mod test { use crate::replace_types::handlers::linearize_array; use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; - use crate::ReplaceTypes; + use crate::{ComposablePass, ReplaceTypes}; const LIN_T: &str = "Lin"; diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index dbe04edd19..874fd9ec3d 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -10,19 +10,19 @@ use hugr_core::hugr::views::SiblingSubgraph; use hugr_core::hugr::SimpleReplacementError; use hugr_core::ops::{NamedOp, OpTrait, OpType}; use hugr_core::types::Type; -use hugr_core::{HugrView, SimpleReplacement}; +use hugr_core::{HugrView, Node, SimpleReplacement}; use itertools::Itertools; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration enum for the untuple rewrite pass. /// /// Indicates whether the pattern match should traverse the HUGR recursively. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum UntupleRecursive { - /// Traverse the HUGR recursively. + /// Traverse the HUGR recursively, i.e. consider the entire subtree Recursive, - /// Do not traverse the HUGR recursively. + /// Do not traverse the HUGR recursively, i.e. consider only the sibling subgraph #[default] NonRecursive, } @@ -48,22 +48,20 @@ pub enum UntupleRecursive { pub struct UntuplePass { /// Whether to traverse the HUGR recursively. recursive: UntupleRecursive, - /// The level of validation to perform on the rewrite. - validation: ValidationLevel, + /// Parent node under which to operate; None indicates the Hugr root + parent: Option, } #[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)] #[non_exhaustive] /// Errors produced by [UntuplePass]. pub enum UntupleError { - /// An error occurred while validating the rewrite. - ValidationError(ValidatePassError), /// Rewriting the circuit failed. RewriteError(SimpleReplacementError), } /// Result type for the untuple pass. -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Copy, Default, PartialEq)] pub struct UntupleResult { /// Number of `MakeTuple` rewrites applied. pub rewrites_applied: usize, @@ -71,16 +69,16 @@ pub struct UntupleResult { impl UntuplePass { /// Create a new untuple pass with the given configuration. - pub fn new(recursive: UntupleRecursive, validation: ValidationLevel) -> Self { + pub fn new(recursive: UntupleRecursive) -> Self { Self { recursive, - validation, + parent: None, } } - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; + /// Sets the parent node to optimize (overwrites any previous setting) + pub fn set_parent(mut self, parent: impl Into>) -> Self { + self.parent = parent.into(); self } @@ -90,31 +88,6 @@ impl UntuplePass { self } - /// Run the pass using specified configuration. - pub fn run( - &self, - hugr: &mut H, - parent: H::Node, - ) -> Result { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr, parent)) - } - - /// Run the Monomorphization pass. - fn run_no_validate( - &self, - hugr: &mut H, - parent: H::Node, - ) -> Result { - let rewrites = self.find_rewrites(hugr, parent); - let rewrites_applied = rewrites.len(); - // The rewrites are independent, so we can always apply them all. - for rewrite in rewrites { - hugr.apply_rewrite(rewrite)?; - } - Ok(UntupleResult { rewrites_applied }) - } - /// Find tuple pack operations followed by tuple unpack operations /// and generate rewrites to remove them. /// @@ -148,6 +121,22 @@ impl UntuplePass { } } +impl ComposablePass for UntuplePass { + type Error = UntupleError; + + type Result = UntupleResult; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let rewrites = self.find_rewrites(hugr, self.parent.unwrap_or(hugr.root())); + let rewrites_applied = rewrites.len(); + // The rewrites are independent, so we can always apply them all. + for rewrite in rewrites { + hugr.apply_rewrite(rewrite)?; + } + Ok(UntupleResult { rewrites_applied }) + } +} + /// Returns true if the given optype is a MakeTuple operation. /// /// Boilerplate required due to https://github.com/CQCL/hugr/issues/1496 @@ -421,7 +410,8 @@ mod test { let parent = hugr.root(); let res = pass - .run(&mut hugr, parent) + .set_parent(parent) + .run(&mut hugr) .unwrap_or_else(|e| panic!("{e}")); assert_eq!(res.rewrites_applied, expected_rewrites); assert_eq!(hugr.children(parent).count(), remaining_nodes);