Skip to content

Commit 1121fb0

Browse files
acl-cqcaborgna-q
authored andcommitted
feat!: ComposablePass trait allowing sequencing and validation (#1895)
Currently We have several "passes": monomorphization, dead function removal, constant folding. Each has its own code to allow setting a validation level (before and after that pass). This PR adds the ability chain (sequence) passes;, and to add validation before+after any pass or sequence; and commons up validation code. The top-level `constant_fold_pass` (etc.) functions are left as wrappers that do a single pass with validation only in test. I've left ConstFoldPass as always including DCE, but an alternative could be to return a sequence of the two - ATM that means a tuple `(ConstFoldPass, DeadCodeElimPass)`. I also wondered about including a method `add_entry_point` in ComposablePass (e.g. for ConstFoldPass, that means `with_inputs` but no inputs, i.e. all Top). I feel this is not applicable to *all* passes, but near enough. This could be done in a later PR but `add_entry_point` would need a no-op default for that to be a non-breaking change. So if we wouldn't be happy with the no-op default then I could just add it here... Finally...docs are extremely minimal ATM (this is hugr-passes), I am hoping that most of this is reasonably obvious (it doesn't really do a lot!), but please flag anything you think is particularly in need of a doc comment! BREAKING CHANGE: quite a lot of calls to current pass routines will break, specific cases include (a) `with_validation_level` should be done by wrapping a ValidatingPass around the receiver; (b) XXXPass::run() requires `use ...ComposablePass` (however, such calls will cease to do any validation). closes #1832
1 parent 70881b7 commit 1121fb0

File tree

10 files changed

+550
-265
lines changed

10 files changed

+550
-265
lines changed

hugr-passes/src/composable.rs

Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
//! Compiler passes and utilities for composing them
2+
3+
use std::{error::Error, marker::PhantomData};
4+
5+
use hugr_core::hugr::{hugrmut::HugrMut, ValidationError};
6+
use hugr_core::HugrView;
7+
use itertools::Either;
8+
9+
/// An optimization pass that can be sequenced with another and/or wrapped
10+
/// e.g. by [ValidatingPass]
11+
pub trait ComposablePass: Sized {
12+
type Error: Error;
13+
type Result; // Would like to default to () but currently unstable
14+
15+
fn run(&self, hugr: &mut impl HugrMut) -> Result<Self::Result, Self::Error>;
16+
17+
fn map_err<E2: Error>(
18+
self,
19+
f: impl Fn(Self::Error) -> E2,
20+
) -> impl ComposablePass<Result = Self::Result, Error = E2> {
21+
ErrMapper::new(self, f)
22+
}
23+
24+
/// Returns a [ComposablePass] that does "`self` then `other`", so long as
25+
/// `other::Err` can be combined with ours.
26+
fn then<P: ComposablePass, E: ErrorCombiner<Self::Error, P::Error>>(
27+
self,
28+
other: P,
29+
) -> impl ComposablePass<Result = (Self::Result, P::Result), Error = E> {
30+
struct Sequence<E, P1, P2>(P1, P2, PhantomData<E>);
31+
impl<E, P1, P2> ComposablePass for Sequence<E, P1, P2>
32+
where
33+
P1: ComposablePass,
34+
P2: ComposablePass,
35+
E: ErrorCombiner<P1::Error, P2::Error>,
36+
{
37+
type Error = E;
38+
39+
type Result = (P1::Result, P2::Result);
40+
41+
fn run(&self, hugr: &mut impl HugrMut) -> Result<Self::Result, Self::Error> {
42+
let res1 = self.0.run(hugr).map_err(E::from_first)?;
43+
let res2 = self.1.run(hugr).map_err(E::from_second)?;
44+
Ok((res1, res2))
45+
}
46+
}
47+
48+
Sequence(self, other, PhantomData)
49+
}
50+
}
51+
52+
/// Trait for combining the error types from two different passes
53+
/// into a single error.
54+
pub trait ErrorCombiner<A, B>: Error {
55+
fn from_first(a: A) -> Self;
56+
fn from_second(b: B) -> Self;
57+
}
58+
59+
impl<A: Error, B: Into<A>> ErrorCombiner<A, B> for A {
60+
fn from_first(a: A) -> Self {
61+
a
62+
}
63+
64+
fn from_second(b: B) -> Self {
65+
b.into()
66+
}
67+
}
68+
69+
impl<A: Error, B: Error> ErrorCombiner<A, B> for Either<A, B> {
70+
fn from_first(a: A) -> Self {
71+
Either::Left(a)
72+
}
73+
74+
fn from_second(b: B) -> Self {
75+
Either::Right(b)
76+
}
77+
}
78+
79+
// Note: in the short term we could wish for two more impls:
80+
// impl<E:Error> ErrorCombiner<Infallible, E> for E
81+
// impl<E:Error> ErrorCombiner<E, Infallible> for E
82+
// however, these aren't possible as they conflict with
83+
// impl<A, B:Into<A>> ErrorCombiner<A,B> for A
84+
// when A=E=Infallible, boo :-(.
85+
// However this will become possible, indeed automatic, when Infallible is replaced
86+
// by ! (never_type) as (unlike Infallible) ! converts Into anything
87+
88+
// ErrMapper ------------------------------
89+
struct ErrMapper<P, E, F>(P, F, PhantomData<E>);
90+
91+
impl<P: ComposablePass, E: Error, F: Fn(P::Error) -> E> ErrMapper<P, E, F> {
92+
fn new(pass: P, err_fn: F) -> Self {
93+
Self(pass, err_fn, PhantomData)
94+
}
95+
}
96+
97+
impl<P: ComposablePass, E: Error, F: Fn(P::Error) -> E> ComposablePass for ErrMapper<P, E, F> {
98+
type Error = E;
99+
type Result = P::Result;
100+
101+
fn run(&self, hugr: &mut impl HugrMut) -> Result<P::Result, Self::Error> {
102+
self.0.run(hugr).map_err(&self.1)
103+
}
104+
}
105+
106+
// ValidatingPass ------------------------------
107+
108+
/// Error from a [ValidatingPass]
109+
#[derive(thiserror::Error, Debug)]
110+
pub enum ValidatePassError<E> {
111+
#[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")]
112+
Input {
113+
#[source]
114+
err: ValidationError,
115+
pretty_hugr: String,
116+
},
117+
#[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")]
118+
Output {
119+
#[source]
120+
err: ValidationError,
121+
pretty_hugr: String,
122+
},
123+
#[error(transparent)]
124+
Underlying(#[from] E),
125+
}
126+
127+
/// Runs an underlying pass, but with validation of the Hugr
128+
/// both before and afterwards.
129+
pub struct ValidatingPass<P>(P, bool);
130+
131+
impl<P: ComposablePass> ValidatingPass<P> {
132+
pub fn new_default(underlying: P) -> Self {
133+
// Self(underlying, cfg!(feature = "extension_inference"))
134+
// Sadly, many tests fail with extension inference, hence:
135+
Self(underlying, false)
136+
}
137+
138+
pub fn new_validating_extensions(underlying: P) -> Self {
139+
Self(underlying, true)
140+
}
141+
142+
pub fn new(underlying: P, validate_extensions: bool) -> Self {
143+
Self(underlying, validate_extensions)
144+
}
145+
146+
fn validation_impl<E>(
147+
&self,
148+
hugr: &impl HugrView,
149+
mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError<E>,
150+
) -> Result<(), ValidatePassError<E>> {
151+
match self.1 {
152+
false => hugr.validate_no_extensions(),
153+
true => hugr.validate(),
154+
}
155+
.map_err(|err| mk_err(err, hugr.mermaid_string()))
156+
}
157+
}
158+
159+
impl<P: ComposablePass> ComposablePass for ValidatingPass<P> {
160+
type Error = ValidatePassError<P::Error>;
161+
type Result = P::Result;
162+
163+
fn run(&self, hugr: &mut impl HugrMut) -> Result<P::Result, Self::Error> {
164+
self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input {
165+
err,
166+
pretty_hugr,
167+
})?;
168+
let res = self.0.run(hugr).map_err(ValidatePassError::Underlying)?;
169+
self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output {
170+
err,
171+
pretty_hugr,
172+
})?;
173+
Ok(res)
174+
}
175+
}
176+
177+
// IfThen ------------------------------
178+
/// [ComposablePass] that executes a first pass that returns a `bool`
179+
/// result; and then, if-and-only-if that first result was true,
180+
/// executes a second pass
181+
pub struct IfThen<E, A, B>(A, B, PhantomData<E>);
182+
183+
impl<A: ComposablePass<Result = bool>, B: ComposablePass, E: ErrorCombiner<A::Error, B::Error>>
184+
IfThen<E, A, B>
185+
{
186+
/// Make a new instance given the [ComposablePass] to run first
187+
/// and (maybe) second
188+
pub fn new(fst: A, opt_snd: B) -> Self {
189+
Self(fst, opt_snd, PhantomData)
190+
}
191+
}
192+
193+
impl<A: ComposablePass<Result = bool>, B: ComposablePass, E: ErrorCombiner<A::Error, B::Error>>
194+
ComposablePass for IfThen<E, A, B>
195+
{
196+
type Error = E;
197+
198+
type Result = Option<B::Result>;
199+
200+
fn run(&self, hugr: &mut impl HugrMut) -> Result<Self::Result, Self::Error> {
201+
let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?;
202+
res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second))
203+
.transpose()
204+
}
205+
}
206+
207+
pub(crate) fn validate_if_test<P: ComposablePass>(
208+
pass: P,
209+
hugr: &mut impl HugrMut,
210+
) -> Result<P::Result, ValidatePassError<P::Error>> {
211+
if cfg!(test) {
212+
ValidatingPass::new_default(pass).run(hugr)
213+
} else {
214+
pass.run(hugr).map_err(ValidatePassError::Underlying)
215+
}
216+
}
217+
218+
#[cfg(test)]
219+
mod test {
220+
use itertools::{Either, Itertools};
221+
use std::convert::Infallible;
222+
223+
use hugr_core::builder::{
224+
Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
225+
ModuleBuilder,
226+
};
227+
use hugr_core::extension::prelude::{
228+
bool_t, usize_t, ConstUsize, MakeTuple, UnpackTuple, PRELUDE_ID,
229+
};
230+
use hugr_core::hugr::hugrmut::HugrMut;
231+
use hugr_core::ops::{handle::NodeHandle, Input, OpType, Output, DEFAULT_OPTYPE, DFG};
232+
use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES;
233+
use hugr_core::types::{Signature, TypeRow};
234+
use hugr_core::{Hugr, HugrView, IncomingPort};
235+
236+
use crate::const_fold::{ConstFoldError, ConstantFoldPass};
237+
use crate::untuple::{UntupleRecursive, UntupleResult};
238+
use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass};
239+
240+
use super::{validate_if_test, ComposablePass, IfThen, ValidatePassError, ValidatingPass};
241+
242+
#[test]
243+
fn test_then() {
244+
let mut mb = ModuleBuilder::new();
245+
let id1 = mb
246+
.define_function("id1", Signature::new_endo(usize_t()))
247+
.unwrap();
248+
let inps = id1.input_wires();
249+
let id1 = id1.finish_with_outputs(inps).unwrap();
250+
let id2 = mb
251+
.define_function("id2", Signature::new_endo(usize_t()))
252+
.unwrap();
253+
let inps = id2.input_wires();
254+
let id2 = id2.finish_with_outputs(inps).unwrap();
255+
let hugr = mb.finish_hugr().unwrap();
256+
257+
let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]);
258+
let cfold =
259+
ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]);
260+
261+
cfold.run(&mut hugr.clone()).unwrap();
262+
263+
let exp_err = ConstFoldError::InvalidEntryPoint(id2.node(), DEFAULT_OPTYPE);
264+
let r: Result<_, Either<Infallible, ConstFoldError>> =
265+
dce.clone().then(cfold.clone()).run(&mut hugr.clone());
266+
assert_eq!(r, Err(Either::Right(exp_err.clone())));
267+
268+
let r = dce
269+
.clone()
270+
.map_err(|inf| match inf {})
271+
.then(cfold.clone())
272+
.run(&mut hugr.clone());
273+
assert_eq!(r, Err(exp_err));
274+
275+
let r2: Result<_, Either<_, _>> = cfold.then(dce).run(&mut hugr.clone());
276+
r2.unwrap();
277+
}
278+
279+
#[test]
280+
fn test_validation() {
281+
let mut h = Hugr::new(DFG {
282+
signature: Signature::new(usize_t(), bool_t()),
283+
});
284+
let inp = h.add_node_with_parent(
285+
h.root(),
286+
Input {
287+
types: usize_t().into(),
288+
},
289+
);
290+
let outp = h.add_node_with_parent(
291+
h.root(),
292+
Output {
293+
types: bool_t().into(),
294+
},
295+
);
296+
h.connect(inp, 0, outp, 0);
297+
let backup = h.clone();
298+
let err = backup.validate().unwrap_err();
299+
300+
let no_inputs: [(IncomingPort, _); 0] = [];
301+
let cfold = ConstantFoldPass::default().with_inputs(backup.root(), no_inputs);
302+
cfold.run(&mut h).unwrap();
303+
assert_eq!(h, backup); // Did nothing
304+
305+
let r = ValidatingPass(cfold, false).run(&mut h);
306+
assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err));
307+
}
308+
309+
#[test]
310+
fn test_if_then() {
311+
let tr = TypeRow::from(vec![usize_t(); 2]);
312+
313+
let h = {
314+
let sig = Signature::new_endo(tr.clone()).with_extension_delta(PRELUDE_ID);
315+
let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap();
316+
let [a, b] = fb.input_wires_arr();
317+
let tup = fb
318+
.add_dataflow_op(MakeTuple::new(tr.clone()), [a, b])
319+
.unwrap();
320+
let untup = fb
321+
.add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs())
322+
.unwrap();
323+
fb.finish_hugr_with_outputs(untup.outputs()).unwrap()
324+
};
325+
326+
let untup = UntuplePass::new(UntupleRecursive::Recursive);
327+
{
328+
// Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple
329+
let mut repl = ReplaceTypes::default();
330+
let usize_custom_t = usize_t().as_extension().unwrap().clone();
331+
repl.replace_type(usize_custom_t, INT_TYPES[6].clone());
332+
let ifthen = IfThen::<Either<_, _>, _, _>::new(repl, untup.clone());
333+
334+
let mut h = h.clone();
335+
let r = validate_if_test(ifthen, &mut h).unwrap();
336+
assert_eq!(
337+
r,
338+
Some(UntupleResult {
339+
rewrites_applied: 1
340+
})
341+
);
342+
let [tuple_in, tuple_out] = h.children(h.root()).collect_array().unwrap();
343+
assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]);
344+
}
345+
346+
// Change INT_TYPES[5] to INT_TYPES[6]; that won't do anything, so don't Untuple
347+
let mut repl = ReplaceTypes::default();
348+
let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone();
349+
repl.replace_type(i32_custom_t, INT_TYPES[6].clone());
350+
let ifthen = IfThen::<Either<_, _>, _, _>::new(repl, untup);
351+
let mut h = h;
352+
let r = validate_if_test(ifthen, &mut h).unwrap();
353+
assert_eq!(r, None);
354+
assert_eq!(h.children(h.root()).count(), 4);
355+
let mktup = h
356+
.output_neighbours(h.first_child(h.root()).unwrap())
357+
.next()
358+
.unwrap();
359+
assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr)));
360+
}
361+
}

0 commit comments

Comments
 (0)