From c328a181cc9880c6986eb51cf95fbdd6e78b681c Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Thu, 6 Feb 2025 14:32:43 +0000 Subject: [PATCH 01/61] wip --- hugr-core/src/hugr/hugrmut.rs | 45 +++++- hugr-core/src/hugr/internal.rs | 17 ++- hugr-core/src/types.rs | 20 ++- hugr-passes/src/non_local.rs | 241 ++++++++++++++++++++++++++++++++- 4 files changed, 316 insertions(+), 7 deletions(-) diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 4056f36e61..b76d1897fe 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -4,6 +4,7 @@ use core::panic; use std::collections::HashMap; use std::sync::Arc; +use itertools::Itertools as _; use portgraph::view::{NodeFilter, NodeFiltered}; use portgraph::{LinkMut, NodeIndex, PortMut, PortView, SecondaryMap}; @@ -11,7 +12,7 @@ use crate::extension::ExtensionRegistry; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrView, Node, OpType, RootTagged}; use crate::hugr::{NodeMetadata, Rewrite}; -use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; +use crate::{Direction, Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; use super::internal::HugrMutInternals; use super::NodeMetadataMap; @@ -278,6 +279,48 @@ pub trait HugrMut: HugrMutInternals { fn extensions_mut(&mut self) -> &mut ExtensionRegistry { &mut self.hugr_mut().extensions } + + /// TODO perhaps these should be on HugrMut? + fn insert_incoming_port(&mut self, node: Node, index: usize) -> IncomingPort { + let _ = self + .add_ports(node, Direction::Incoming, 1) + .exactly_one() + .unwrap(); + + for (to, from) in (index..self.num_inputs(node)) + .map_into::() + .rev() + .tuple_windows() + { + let linked_outputs = self.linked_outputs(node, from).collect_vec(); + self.disconnect(node, from); + for (linked_node, linked_port) in linked_outputs { + self.connect(linked_node, linked_port, node, to); + } + } + index.into() + } + + /// TODO perhaps these should be on HugrMut? + fn insert_outgoing_port(&mut self, node: Node, index: usize) -> OutgoingPort { + let _ = self + .add_ports(node, Direction::Outgoing, 1) + .exactly_one() + .unwrap(); + + for (to, from) in (index..self.num_outputs(node)) + .map_into::() + .rev() + .tuple_windows() + { + let linked_inputs = self.linked_inputs(node, from).collect_vec(); + self.disconnect(node, from); + for (linked_node, linked_port) in linked_inputs { + self.connect(node, to, linked_node, linked_port); + } + } + index.into() + } } /// Records the result of inserting a Hugr or view diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 3f1c6b6ff7..30f8727161 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -245,6 +245,13 @@ pub trait HugrMutInternals: RootTagged { } self.hugr_mut().replace_op(node, op) } + + /// TODO docs + fn get_optype_mut(&mut self, node: Node) -> Result<&mut OpType, HugrError> { + panic_invalid_node(self, node); + // TODO refuse if node == self.root() because tag might be violated + self.hugr_mut().get_optype_mut(node) + } } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. @@ -305,7 +312,13 @@ impl + AsMut> HugrMutInternals for T { fn replace_op(&mut self, node: Node, op: impl Into) -> Result { // We know RootHandle=Node here so no need to check - let cur = self.hugr_mut().op_types.get_mut(node.pg_index()); - Ok(std::mem::replace(cur, op.into())) + Ok(std::mem::replace( + self.hugr_mut().get_optype_mut(node)?, + op.into(), + )) + } + + fn get_optype_mut(&mut self, node: Node) -> Result<&mut OpType, HugrError> { + Ok(self.hugr_mut().op_types.get_mut(node.pg_index())) } } diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 962e00876c..025666eb0a 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -27,7 +27,7 @@ pub use type_row::{TypeRow, TypeRowRV}; pub(crate) use poly_func::PolyFuncTypeBase; use itertools::FoldWhile::{Continue, Done}; -use itertools::{repeat_n, Itertools}; +use itertools::{repeat_n, Either, Itertools}; #[cfg(test)] use proptest_derive::Arbitrary; use serde::{Deserialize, Serialize}; @@ -256,6 +256,16 @@ impl SumType { _ => None, } } + + /// TODO docs + pub fn iter_variants(&self) -> impl Iterator { + match self { + SumType::Unit { size } => { + Either::Left(repeat_n(TypeRV::EMPTY_TYPEROW_REF, *size as usize)) + } + SumType::General { rows } => Either::Right(rows.iter()), + } + } } impl From for TypeBase { @@ -453,6 +463,14 @@ impl TypeBase { &mut self.0 } + /// TODO docs + pub fn as_sum_type(&self) -> Option<&SumType> { + match &self.0 { + TypeEnum::Sum(s) => Some(s), + _ => None, + } + } + /// Report if the type is copyable - i.e.the least upper bound of the type /// is contained by the copyable bound. pub const fn copyable(&self) -> bool { diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index efb5e7139e..5b31de9b13 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,11 +1,18 @@ //! This module provides functions for inspecting and modifying the nature of //! non local edges in a Hugr. +use ascent::hashbrown::HashMap; // //TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions -use itertools::Itertools as _; +use itertools::{Either, Itertools as _}; use thiserror::Error; -use hugr_core::{HugrView, IncomingPort, Node}; +use hugr_core::{ + builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, + hugr::hugrmut::HugrMut, + ops::{DataflowOpTrait as _, OpType, Tag, TailLoop}, + types::{EdgeKind, Type, TypeRow}, + HugrView, IncomingPort, Node, PortIndex, Wire, +}; /// Returns an iterator over all non local edges in a Hugr. /// @@ -38,12 +45,178 @@ pub fn ensure_no_nonlocal_edges(hugr: &impl HugrView) -> Result<(), NonLocalEdge } } +#[derive(Debug, Clone)] +struct WorkItem { + source: Wire, + target: (Node, IncomingPort), + ty: Type, +} + +fn thread_dataflow_parent( + hugr: &mut impl HugrMut, + parent: Node, + port_index: usize, + ty: Type, +) -> Wire { + let [i, _] = hugr.get_io(parent).unwrap(); + let OpType::Input(mut input) = hugr.get_optype(i).clone() else { + panic!("impossible") + }; + input.types.to_mut().insert(port_index, ty); + hugr.replace_op(i, input).unwrap(); + let input_port = hugr.insert_outgoing_port(i, port_index); + Wire::new(i, input_port) +} + +fn do_tailloop(hugr: &mut impl HugrMut, node: Node, source: Wire, ty: Type) -> (WorkItem, Wire) { + let mut tailloop = hugr.get_optype(node).as_tail_loop().unwrap().clone(); + let new_port_index = tailloop.just_inputs.len(); + tailloop.just_inputs.to_mut().push(ty.clone()); + hugr.replace_op(node, tailloop).unwrap(); + let tailloop_port = hugr.insert_incoming_port(node, new_port_index); + hugr.connect(source.node(), source.source(), node, tailloop_port); + let workitem = WorkItem { + source, + target: (node, tailloop_port), + ty: ty.clone(), + }; + + let input_wire = thread_dataflow_parent(hugr, node, tailloop_port.index(), ty.clone()); + + let [_, o] = hugr.get_io(node).unwrap(); + let (cond, new_control_type) = { + let Some(EdgeKind::Value(control_type)) = + hugr.get_optype(o).port_kind(IncomingPort::from(0)) + else { + panic!("impossible") + }; + let Some(sum_type) = control_type.as_sum_type() else { + panic!("impossible") + }; + + let old_sum_rows: Vec = sum_type + .iter_variants() + .map(|x| x.clone().try_into().unwrap()) + .collect_vec(); + let new_sum_rows = { + let mut v = old_sum_rows.clone(); + v[TailLoop::CONTINUE_TAG].to_mut().push(ty.clone()); + v + }; + + let new_control_type = Type::new_sum(new_sum_rows.clone()); + let mut cond = + ConditionalBuilder::new(old_sum_rows, ty.clone(), new_control_type.clone()).unwrap(); + for i in 0..2 { + let mut case = cond.case_builder(i).unwrap(); + let inputs = { + let all_inputs = case.input_wires(); + if i == TailLoop::CONTINUE_TAG { + Either::Left(all_inputs) + } else { + Either::Right(all_inputs.into_iter().dropping_back(1)) + } + }; + + let case_outputs = case + .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), inputs) + .unwrap() + .outputs(); + case.finish_with_outputs(case_outputs).unwrap(); + } + (cond.finish_hugr().unwrap(), new_control_type) + }; + let cond_node = hugr.insert_hugr(node, cond).new_root; + let (n, p) = hugr.single_linked_output(o, 0).unwrap(); + hugr.connect(n, p, cond_node, 0); + hugr.connect(input_wire.node(), input_wire.source(), cond_node, 1); + hugr.disconnect(o, IncomingPort::from(0)); + hugr.connect(cond_node, 0, o, 0); + let mut output = hugr.get_optype(o).as_output().unwrap().clone(); + output.types.to_mut()[0] = new_control_type; + hugr.replace_op(o, output).unwrap(); + (workitem, input_wire) +} + +pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdgesError> { + let mut non_local_edges = nonlocal_edges(hugr) + .map(|target @ (node, inport)| { + let source = { + let (n, p) = hugr.single_linked_output(node, inport).unwrap(); + Wire::new(n, p) + }; + debug_assert!( + hugr.get_parent(source.node()).unwrap() != hugr.get_parent(node).unwrap() + ); + let Some(EdgeKind::Value(ty)) = hugr + .get_optype(hugr.get_parent(source.node()).unwrap()) + .port_kind(source.source()) + else { + panic!("impossible") + }; + WorkItem { source, target, ty } + }) + .collect_vec(); + + if non_local_edges.is_empty() { + return Ok(()); + } + + let mut parent_source_map = HashMap::new(); + + while let Some(WorkItem { source, target, ty }) = non_local_edges.pop() { + dbg!(&source, target, &ty); + let parent = hugr.get_parent(target.0).unwrap(); + let local_source = if hugr.get_parent(source.node()).unwrap() == parent { + &source + } else { + parent_source_map + .entry((parent, source)) + .or_insert_with(|| { + let (workitem, wire) = match hugr.get_optype(parent).clone() { + OpType::DFG(mut dfg) => { + let new_port_index = dfg.signature.input.len(); + dbg!(&dfg, new_port_index); + dfg.signature.input.to_mut().push(ty.clone()); + hugr.replace_op(parent, dfg).unwrap(); + let dfg_port = hugr.insert_incoming_port(parent, new_port_index); + hugr.connect(source.node(), source.source(), parent, dfg_port); + ( + WorkItem { + source, + target: (parent, dfg_port), + ty: ty.clone(), + }, + thread_dataflow_parent(hugr, parent, dfg_port.index(), ty), + ) + } + OpType::DataflowBlock(dataflow_block) => todo!(), + OpType::TailLoop(_) => do_tailloop(hugr, parent, source, ty), + OpType::Case(case) => todo!(), + _ => panic!("impossible"), + }; + non_local_edges.push(workitem); + wire + }) + }; + hugr.disconnect(target.0, target.1); + hugr.connect( + local_source.node(), + local_source.source(), + target.0, + target.1, + ); + } + + Ok(()) +} + #[cfg(test)] mod test { use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}, extension::prelude::{bool_t, Noop}, - ops::handle::NodeHandle, + ops::{handle::NodeHandle, Tag, TailLoop}, type_row, types::Signature, }; @@ -94,4 +267,66 @@ mod test { NonLocalEdgesError::Edges(vec![edge]) ); } + + #[test] + fn remove_nonlocal_edges_dfg() { + let mut hugr = { + let mut outer = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); + let [w0] = outer.input_wires_arr(); + let [w1] = { + let inner = outer + .dfg_builder(Signature::new(type_row![], bool_t()), []) + .unwrap(); + inner.finish_with_outputs([w0]).unwrap().outputs_arr() + }; + outer.finish_hugr_with_outputs([w1]).unwrap() + }; + assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + remove_nonlocal_edges(&mut hugr).unwrap(); + hugr.validate().unwrap(); + assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + } + + #[test] + fn remove_nonlocal_edges_tailloop() { + let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); + let mut hugr = { + let mut outer = DFGBuilder::new(Signature::new_endo(vec![ + t1.clone(), + t2.clone(), + t3.clone(), + ])) + .unwrap(); + let [s1, s2, s3] = outer.input_wires_arr(); + let [s2, s3] = { + let mut inner = outer + .tail_loop_builder( + [(t1.clone(), s1)], + [(t3.clone(), s3)], + vec![t2.clone()].into(), + ) + .unwrap(); + let [_s1, s3] = inner.input_wires_arr(); + let control = inner + .add_dataflow_op( + Tag::new( + TailLoop::BREAK_TAG, + vec![vec![t1.clone()].into(), vec![t2.clone()].into()], + ), + [s2], + ) + .unwrap() + .out_wire(0); + inner + .finish_with_outputs(control, [s3]) + .unwrap() + .outputs_arr() + }; + outer.finish_hugr_with_outputs([s1, s2, s3]).unwrap() + }; + assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + remove_nonlocal_edges(&mut hugr).unwrap(); + hugr.validate().unwrap_or_else(|e| panic!("{e}")); + assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + } } From cd2348d576a7d45ff576677a3a7d262273f100d3 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Thu, 6 Feb 2025 14:56:08 +0000 Subject: [PATCH 02/61] add pass --- hugr-passes/Cargo.toml | 1 + hugr-passes/src/non_local.rs | 37 ++++++++++++++++++++++++++++++++--- hugr-passes/src/validation.rs | 2 +- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index c02d3d8591..39162a52e6 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -24,6 +24,7 @@ lazy_static = { workspace = true } paste = { workspace = true } thiserror = { workspace = true } petgraph = { workspace = true } +derive_more = { workspace = true, features = ["from", "error", "display"] } [features] extension_inference = ["hugr-core/extension_inference"] diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 5b31de9b13..9e5af7cb85 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -4,7 +4,6 @@ use ascent::hashbrown::HashMap; // //TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions use itertools::{Either, Itertools as _}; -use thiserror::Error; use hugr_core::{ builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, @@ -14,6 +13,34 @@ use hugr_core::{ HugrView, IncomingPort, Node, PortIndex, Wire, }; +use crate::validation::{ValidatePassError, ValidationLevel}; + +/// TODO docs +#[derive(Debug, Clone, Default)] +pub struct UnNonLocalPass { + validation: ValidationLevel, +} + +impl UnNonLocalPass { + /// Sets the validation level used before and after the pass is run. + pub fn validation_level(mut self, level: ValidationLevel) -> Self { + self.validation = level; + self + } + + /// Run the Monomorphization pass. + fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), NonLocalEdgesError> { + remove_nonlocal_edges(hugr)?; + Ok(()) + } + + /// Run the pass using specified configuration. + pub fn run(&self, hugr: &mut H) -> Result<(), NonLocalEdgesError> { + self.validation + .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) + } +} + /// Returns an iterator over all non local edges in a Hugr. /// /// All `(node, in_port)` pairs are returned where `in_port` is a value port @@ -29,10 +56,14 @@ pub fn nonlocal_edges(hugr: &impl HugrView) -> impl Iterator), + #[from] + ValidationError(ValidatePassError), } /// Verifies that there are no non local value edges in the Hugr. diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs index baf3b86d83..5f53f403c7 100644 --- a/hugr-passes/src/validation.rs +++ b/hugr-passes/src/validation.rs @@ -23,7 +23,7 @@ pub enum ValidationLevel { WithExtensions, } -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq)] #[allow(missing_docs)] pub enum ValidatePassError { #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] From e04390ebefacec82b4ea1ad09fffe9285a44bad8 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Thu, 6 Feb 2025 15:01:07 +0000 Subject: [PATCH 03/61] oops --- hugr-passes/src/non_local.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 9e5af7cb85..3f588d9e2d 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -180,7 +180,7 @@ pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdge hugr.get_parent(source.node()).unwrap() != hugr.get_parent(node).unwrap() ); let Some(EdgeKind::Value(ty)) = hugr - .get_optype(hugr.get_parent(source.node()).unwrap()) + .get_optype(source.node()) .port_kind(source.source()) else { panic!("impossible") From 903acc2eb847b4254226adf140a8cf4b5ea80bd7 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Thu, 6 Feb 2025 15:42:54 +0000 Subject: [PATCH 04/61] conditional --- hugr-passes/src/non_local.rs | 129 +++++++++++++++++++++++++++-------- 1 file changed, 99 insertions(+), 30 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 3f588d9e2d..455b9e90fd 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -199,36 +199,62 @@ pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdge dbg!(&source, target, &ty); let parent = hugr.get_parent(target.0).unwrap(); let local_source = if hugr.get_parent(source.node()).unwrap() == parent { - &source + source + } else if let Some(wire) = parent_source_map.get(&(parent,source)) { + *wire } else { - parent_source_map - .entry((parent, source)) - .or_insert_with(|| { - let (workitem, wire) = match hugr.get_optype(parent).clone() { - OpType::DFG(mut dfg) => { - let new_port_index = dfg.signature.input.len(); - dbg!(&dfg, new_port_index); - dfg.signature.input.to_mut().push(ty.clone()); - hugr.replace_op(parent, dfg).unwrap(); - let dfg_port = hugr.insert_incoming_port(parent, new_port_index); - hugr.connect(source.node(), source.source(), parent, dfg_port); - ( - WorkItem { - source, - target: (parent, dfg_port), - ty: ty.clone(), - }, - thread_dataflow_parent(hugr, parent, dfg_port.index(), ty), - ) + let (workitem, wire) = match hugr.get_optype(parent).clone() { + OpType::DFG(mut dfg) => { + let new_port_index = dfg.signature.input.len(); + dbg!(&dfg, new_port_index); + dfg.signature.input.to_mut().push(ty.clone()); + hugr.replace_op(parent, dfg).unwrap(); + let dfg_port = hugr.insert_incoming_port(parent, new_port_index); + hugr.connect(source.node(), source.source(), parent, dfg_port); + let wire = thread_dataflow_parent(hugr, parent, dfg_port.index(), ty.clone()); + let _ = parent_source_map.insert((parent, source), wire); + ( + WorkItem { + source, + target: (parent, dfg_port), + ty + }, + wire + ) + } + OpType::DataflowBlock(dataflow_block) => todo!(), + OpType::TailLoop(_) => { + let (workitem, wire) = do_tailloop(hugr, parent, source, ty); + let _ = parent_source_map.insert((parent, source), wire); + (workitem, wire) + } + OpType::Case(_) => { + let cond_node = hugr.get_parent(parent).unwrap(); + let mut cond = hugr.get_optype(cond_node).as_conditional().unwrap().clone(); + let new_port_index = cond.signature().input().len(); + cond.other_inputs.to_mut().push(ty.clone()); + hugr.replace_op(cond_node, cond).unwrap(); + let cond_port = hugr.insert_incoming_port(cond_node, new_port_index); + let mut this_wire = None; + for (case_n, mut case) in hugr.children(cond_node).filter_map(|n| { + let case = hugr.get_optype(n).as_case()?; + Some((n, case.clone())) + }).collect_vec() { + let case_port_index = case.signature.input().len(); + case.signature.input.to_mut().push(ty.clone()); + hugr.replace_op(case_n, case).unwrap(); + let case_input_wire = thread_dataflow_parent(hugr, case_n, case_port_index, ty.clone()); + let _ = parent_source_map.insert((case_n, source), case_input_wire); + if case_n == parent { + this_wire = Some(case_input_wire); } - OpType::DataflowBlock(dataflow_block) => todo!(), - OpType::TailLoop(_) => do_tailloop(hugr, parent, source, ty), - OpType::Case(case) => todo!(), - _ => panic!("impossible"), - }; - non_local_edges.push(workitem); - wire - }) + } + (WorkItem { source, target: (cond_node, cond_port), ty }, this_wire.unwrap()) + } + _ => panic!("impossible"), + }; + non_local_edges.push(workitem); + wire }; hugr.disconnect(target.0, target.1); hugr.connect( @@ -245,9 +271,9 @@ pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdge #[cfg(test)] mod test { use hugr_core::{ - builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}, + builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, extension::prelude::{bool_t, Noop}, - ops::{handle::NodeHandle, Tag, TailLoop}, + ops::{handle::NodeHandle, Tag, TailLoop, Value}, type_row, types::Signature, }; @@ -360,4 +386,47 @@ mod test { hugr.validate().unwrap_or_else(|e| panic!("{e}")); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); } + + #[test] + fn remove_nonlocal_edges_cond() { + let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); + let out_variants = vec![t1.clone().into(), t2.clone().into()]; + let out_type = Type::new_sum(out_variants.clone()); + let mut hugr = { + let mut outer = DFGBuilder::new(Signature::new(vec![ + t1.clone(), + t2.clone(), + t3.clone() + ], out_type.clone())) + .unwrap(); + let [s1, s2, s3] = outer.input_wires_arr(); + let [out] = { + let mut cond = outer + .conditional_builder((vec![type_row![];3], s3), [], out_type.into()).unwrap(); + + { + let mut case = cond.case_builder(0).unwrap(); + let [r] = case.add_dataflow_op(Tag::new(0, out_variants.clone()), [s1]).unwrap().outputs_arr(); + case.finish_with_outputs([r]).unwrap(); + } + { + let mut case = cond.case_builder(1).unwrap(); + let [r] = case.add_dataflow_op(Tag::new(1, out_variants.clone()), [s2]).unwrap().outputs_arr(); + case.finish_with_outputs([r]).unwrap(); + } + { + let mut case = cond.case_builder(2).unwrap(); + let u = case.add_load_value(Value::unit()); + let [r] = case.add_dataflow_op(Tag::new(0, out_variants.clone()), [u]).unwrap().outputs_arr(); + case.finish_with_outputs([r]).unwrap(); + } + cond.finish_sub_container().unwrap().outputs_arr() + }; + outer.finish_hugr_with_outputs([out]).unwrap() + }; + assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + remove_nonlocal_edges(&mut hugr).unwrap(); + hugr.validate().unwrap_or_else(|e| panic!("{e}")); + assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + } } From a87d705f5156a13c0559075d02ab3308bbd0c19c Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Thu, 6 Feb 2025 15:44:58 +0000 Subject: [PATCH 05/61] remove dbg --- hugr-passes/src/non_local.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 455b9e90fd..1bfa764374 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -196,7 +196,6 @@ pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdge let mut parent_source_map = HashMap::new(); while let Some(WorkItem { source, target, ty }) = non_local_edges.pop() { - dbg!(&source, target, &ty); let parent = hugr.get_parent(target.0).unwrap(); let local_source = if hugr.get_parent(source.node()).unwrap() == parent { source @@ -206,7 +205,6 @@ pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdge let (workitem, wire) = match hugr.get_optype(parent).clone() { OpType::DFG(mut dfg) => { let new_port_index = dfg.signature.input.len(); - dbg!(&dfg, new_port_index); dfg.signature.input.to_mut().push(ty.clone()); hugr.replace_op(parent, dfg).unwrap(); let dfg_port = hugr.insert_incoming_port(parent, new_port_index); From 11879a8efdb863dedb357d5e0fbfaca19b5a7875 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Fri, 7 Feb 2025 13:55:02 +0000 Subject: [PATCH 06/61] refactor --- hugr-passes/src/non_local.rs | 481 +++++++++++++++++++++++++---------- 1 file changed, 353 insertions(+), 128 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 1bfa764374..26584d5a34 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,16 +1,23 @@ //! This module provides functions for inspecting and modifying the nature of //! non local edges in a Hugr. -use ascent::hashbrown::HashMap; -// +use std::{ + collections::{BTreeMap, HashMap, HashSet, VecDeque}, + iter, +}; + //TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions use itertools::{Either, Itertools as _}; use hugr_core::{ builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, - hugr::hugrmut::HugrMut, + hugr::{ + hugrmut::HugrMut, + views::{DescendantsGraph, HierarchyView}, + HugrError, + }, ops::{DataflowOpTrait as _, OpType, Tag, TailLoop}, types::{EdgeKind, Type, TypeRow}, - HugrView, IncomingPort, Node, PortIndex, Wire, + HugrView, IncomingPort, Node, Wire, }; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -21,7 +28,7 @@ pub struct UnNonLocalPass { validation: ValidationLevel, } -impl UnNonLocalPass { +impl UnNonLocalPass { /// Sets the validation level used before and after the pass is run. pub fn validation_level(mut self, level: ValidationLevel) -> Self { self.validation = level; @@ -30,7 +37,8 @@ impl UnNonLocalPass { /// Run the Monomorphization pass. fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), NonLocalEdgesError> { - remove_nonlocal_edges(hugr)?; + let root = hugr.root(); + remove_nonlocal_edges(hugr, root)?; Ok(()) } @@ -64,6 +72,8 @@ pub enum NonLocalEdgesError { Edges(Vec<(Node, IncomingPort)>), #[from] ValidationError(ValidatePassError), + #[from] + HugrError(HugrError), } /// Verifies that there are no non local value edges in the Hugr. @@ -86,33 +96,51 @@ struct WorkItem { fn thread_dataflow_parent( hugr: &mut impl HugrMut, parent: Node, - port_index: usize, - ty: Type, -) -> Wire { - let [i, _] = hugr.get_io(parent).unwrap(); - let OpType::Input(mut input) = hugr.get_optype(i).clone() else { + start_port_index: usize, + types: Vec, +) -> impl Iterator { + let [input_n, _] = hugr.get_io(parent).unwrap(); + let OpType::Input(mut input) = hugr.get_optype(input_n).clone() else { panic!("impossible") }; - input.types.to_mut().insert(port_index, ty); - hugr.replace_op(i, input).unwrap(); - let input_port = hugr.insert_outgoing_port(i, port_index); - Wire::new(i, input_port) + let mut r = vec![]; + for (i, ty) in types.into_iter().enumerate() { + input + .types + .to_mut() + .insert(start_port_index + i, ty.clone()); + r.push(Wire::new( + input_n, + hugr.insert_outgoing_port(input_n, start_port_index + i), + )); + } + hugr.replace_op(input_n, input).unwrap(); + r.into_iter() } -fn do_tailloop(hugr: &mut impl HugrMut, node: Node, source: Wire, ty: Type) -> (WorkItem, Wire) { +fn do_tailloop( + parent_source_map: &mut ParentSourceMap, + hugr: &mut impl HugrMut, + node: Node, + sources: impl IntoIterator, +) -> impl Iterator { + let (sources, types): (Vec<_>, Vec<_>) = sources.into_iter().unzip(); let mut tailloop = hugr.get_optype(node).as_tail_loop().unwrap().clone(); - let new_port_index = tailloop.just_inputs.len(); - tailloop.just_inputs.to_mut().push(ty.clone()); - hugr.replace_op(node, tailloop).unwrap(); - let tailloop_port = hugr.insert_incoming_port(node, new_port_index); - hugr.connect(source.node(), source.source(), node, tailloop_port); - let workitem = WorkItem { - source, - target: (node, tailloop_port), - ty: ty.clone(), - }; + let start_port_index = tailloop.just_inputs.len(); + { + tailloop.just_inputs.to_mut().extend(types.iter().cloned()); + hugr.replace_op(node, tailloop).unwrap(); + } + let tailloop_ports = (0..sources.len()) + .map(|i| hugr.insert_incoming_port(node, start_port_index + i)) + .collect_vec(); - let input_wire = thread_dataflow_parent(hugr, node, tailloop_port.index(), ty.clone()); + let input_wires = + thread_dataflow_parent(hugr, node, start_port_index, types.clone()).collect_vec(); + parent_source_map.insert( + node, + iter::zip(sources.iter().copied(), input_wires.iter().copied()), + ); let [_, o] = hugr.get_io(node).unwrap(); let (cond, new_control_type) = { @@ -131,13 +159,15 @@ fn do_tailloop(hugr: &mut impl HugrMut, node: Node, source: Wire, ty: Type) -> ( .collect_vec(); let new_sum_rows = { let mut v = old_sum_rows.clone(); - v[TailLoop::CONTINUE_TAG].to_mut().push(ty.clone()); + v[TailLoop::CONTINUE_TAG] + .to_mut() + .extend(types.iter().cloned()); v }; let new_control_type = Type::new_sum(new_sum_rows.clone()); let mut cond = - ConditionalBuilder::new(old_sum_rows, ty.clone(), new_control_type.clone()).unwrap(); + ConditionalBuilder::new(old_sum_rows, types.clone(), new_control_type.clone()).unwrap(); for i in 0..2 { let mut case = cond.case_builder(i).unwrap(); let inputs = { @@ -145,7 +175,7 @@ fn do_tailloop(hugr: &mut impl HugrMut, node: Node, source: Wire, ty: Type) -> ( if i == TailLoop::CONTINUE_TAG { Either::Left(all_inputs) } else { - Either::Right(all_inputs.into_iter().dropping_back(1)) + Either::Right(all_inputs.into_iter().dropping_back(types.len())) } }; @@ -160,107 +190,290 @@ fn do_tailloop(hugr: &mut impl HugrMut, node: Node, source: Wire, ty: Type) -> ( let cond_node = hugr.insert_hugr(node, cond).new_root; let (n, p) = hugr.single_linked_output(o, 0).unwrap(); hugr.connect(n, p, cond_node, 0); - hugr.connect(input_wire.node(), input_wire.source(), cond_node, 1); + for (i, w) in input_wires.into_iter().enumerate() { + hugr.connect(w.node(), w.source(), cond_node, i + 1); + } hugr.disconnect(o, IncomingPort::from(0)); hugr.connect(cond_node, 0, o, 0); let mut output = hugr.get_optype(o).as_output().unwrap().clone(); output.types.to_mut()[0] = new_control_type; hugr.replace_op(o, output).unwrap(); - (workitem, input_wire) + mk_workitems(node, sources, tailloop_ports, types) } -pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdgesError> { - let mut non_local_edges = nonlocal_edges(hugr) - .map(|target @ (node, inport)| { - let source = { - let (n, p) = hugr.single_linked_output(node, inport).unwrap(); - Wire::new(n, p) - }; - debug_assert!( - hugr.get_parent(source.node()).unwrap() != hugr.get_parent(node).unwrap() +#[derive(Clone, Default, Debug)] +struct ParentSourceMap(HashMap>); + +impl ParentSourceMap { + fn contains_parent(&self, parent: Node) -> bool { + self.0.contains_key(&parent) + } + + fn insert(&mut self, parent: Node, sources: impl IntoIterator) { + debug_assert!(!self.0.contains_key(&parent)); + self.0.entry(parent).or_default().extend(sources); + } + + fn get(&self, parent: Node, source: Wire) -> Option { + self.0.get(&parent).and_then(|m| m.get(&source).cloned()) + } +} + +fn mk_workitems( + node: Node, + sources: impl IntoIterator, + ports: impl IntoIterator, + types: impl IntoIterator, +) -> impl Iterator { + itertools::izip!(sources, ports, types).map(move |(source, p, ty)| WorkItem { + source, + target: (node, p), + ty, + }) +} + +fn thread_sources( + parent_source_map: &mut ParentSourceMap, + hugr: &mut impl HugrMut, + bb: Node, + sources: impl IntoIterator, +) -> Vec { + let (source_wires, types): (Vec<_>, Vec<_>) = sources.into_iter().unzip(); + match hugr.get_optype(bb).clone() { + OpType::DFG(mut dfg) => { + debug_assert!(!parent_source_map.contains_parent(bb)); + let start_new_port_index = dfg.signature.input().len(); + let new_dfg_ports = (0..source_wires.len()) + .map(|i| hugr.insert_incoming_port(bb, start_new_port_index + i)) + .collect_vec(); + dfg.signature.input.to_mut().extend(types.clone()); + hugr.replace_op(bb, dfg).unwrap(); + for (source, &target) in iter::zip(source_wires.iter(), new_dfg_ports.iter()) { + hugr.connect(source.node(), source.source(), bb, target); + } + parent_source_map.insert( + bb, + iter::zip( + source_wires.iter().copied(), + thread_dataflow_parent(hugr, bb, start_new_port_index, types.clone()), + ), ); - let Some(EdgeKind::Value(ty)) = hugr - .get_optype(source.node()) - .port_kind(source.source()) - else { - panic!("impossible") - }; - WorkItem { source, target, ty } - }) - .collect_vec(); + mk_workitems(bb, source_wires, new_dfg_ports, types).collect_vec() + } + OpType::Conditional(mut cond) => { + debug_assert!(!parent_source_map.contains_parent(bb)); + let start_new_port_index = cond.signature().input().len(); + cond.other_inputs.to_mut().extend(types.clone()); + hugr.replace_op(bb, cond).unwrap(); + let new_cond_ports = (0..source_wires.len()) + .map(|i| hugr.insert_incoming_port(bb, start_new_port_index + i)) + .collect_vec(); + parent_source_map.insert(bb, iter::empty()); + mk_workitems(bb, source_wires, new_cond_ports, types).collect_vec() + } + OpType::Case(mut case) => { + debug_assert!(!parent_source_map.contains_parent(bb)); + let start_case_port_index = case.signature.input().len(); + case.signature.input.to_mut().extend(types.clone()); + hugr.replace_op(bb, case).unwrap(); + parent_source_map.insert( + bb, + iter::zip( + source_wires.iter().copied(), + thread_dataflow_parent(hugr, bb, start_case_port_index, types), + ), + ); + vec![] + } + OpType::TailLoop(_) => { + do_tailloop(parent_source_map, hugr, bb, iter::zip(source_wires, types)).collect_vec() + } + _ => panic!("impossible"), + } + // _ => panic!("impossible"), + // }; + // non_local_edges.push(workitem); + // wire + // }; + // hugr.disconnect(target.0, target.1); + // hugr.connect( + // local_source.node(), + // local_source.source(), + // target.0, + // target.1, + // ); + // } +} - if non_local_edges.is_empty() { - return Ok(()); +#[derive(Debug, Default, Clone)] +struct BBNeedsSourcesMapBuilder(HashMap>); + +impl BBNeedsSourcesMapBuilder { + fn insert(&mut self, bb: Node, source: Wire, ty: Type) { + self.0.entry(bb).or_default().insert(source, ty); } - let mut parent_source_map = HashMap::new(); + fn extend_parent_needs_for(&mut self, ref hugr: impl HugrView, child: Node) -> bool { + let parent = hugr.get_parent(child).unwrap(); + let parent_needs = self + .0 + .get(&child) + .into_iter() + .flat_map(move |m| { + m.iter().filter(move |(w, _)| hugr.get_parent(w.node()).unwrap() != parent) + .map(|(&w, ty)| (w, ty.clone())) + }) + .collect_vec(); + let any = !parent_needs.is_empty(); + if any { + self.0.entry(parent).or_default().extend(parent_needs); + } + any + } - while let Some(WorkItem { source, target, ty }) = non_local_edges.pop() { - let parent = hugr.get_parent(target.0).unwrap(); - let local_source = if hugr.get_parent(source.node()).unwrap() == parent { - source - } else if let Some(wire) = parent_source_map.get(&(parent,source)) { - *wire - } else { - let (workitem, wire) = match hugr.get_optype(parent).clone() { - OpType::DFG(mut dfg) => { - let new_port_index = dfg.signature.input.len(); - dfg.signature.input.to_mut().push(ty.clone()); - hugr.replace_op(parent, dfg).unwrap(); - let dfg_port = hugr.insert_incoming_port(parent, new_port_index); - hugr.connect(source.node(), source.source(), parent, dfg_port); - let wire = thread_dataflow_parent(hugr, parent, dfg_port.index(), ty.clone()); - let _ = parent_source_map.insert((parent, source), wire); - ( - WorkItem { - source, - target: (parent, dfg_port), - ty - }, - wire - ) - } - OpType::DataflowBlock(dataflow_block) => todo!(), - OpType::TailLoop(_) => { - let (workitem, wire) = do_tailloop(hugr, parent, source, ty); - let _ = parent_source_map.insert((parent, source), wire); - (workitem, wire) - } - OpType::Case(_) => { - let cond_node = hugr.get_parent(parent).unwrap(); - let mut cond = hugr.get_optype(cond_node).as_conditional().unwrap().clone(); - let new_port_index = cond.signature().input().len(); - cond.other_inputs.to_mut().push(ty.clone()); - hugr.replace_op(cond_node, cond).unwrap(); - let cond_port = hugr.insert_incoming_port(cond_node, new_port_index); - let mut this_wire = None; - for (case_n, mut case) in hugr.children(cond_node).filter_map(|n| { - let case = hugr.get_optype(n).as_case()?; - Some((n, case.clone())) - }).collect_vec() { - let case_port_index = case.signature.input().len(); - case.signature.input.to_mut().push(ty.clone()); - hugr.replace_op(case_n, case).unwrap(); - let case_input_wire = thread_dataflow_parent(hugr, case_n, case_port_index, ty.clone()); - let _ = parent_source_map.insert((case_n, source), case_input_wire); - if case_n == parent { - this_wire = Some(case_input_wire); - } - } - (WorkItem { source, target: (cond_node, cond_port), ty }, this_wire.unwrap()) + fn finish(mut self, hugr: impl HugrView) -> HashMap> { + let conds = self + .0 + .keys() + .copied() + .filter(|&n| hugr.get_optype(n).is_conditional()) + .collect_vec(); + for cond in conds { + if hugr.get_optype(cond).is_conditional() { + let cases = hugr + .children(cond) + .filter(|&child| hugr.get_optype(child).is_case()) + .collect_vec(); + let all_needed: BTreeMap<_, _> = cases + .iter() + .flat_map(|&case| { + let case_needed = self.0.get(&case); + case_needed + .into_iter() + .flat_map(|m| m.iter().map(|(&w, ty)| (w, ty.clone()))) + }) + .collect(); + for case in cases { + let _ = self.0.insert(case, all_needed.clone()); } - _ => panic!("impossible"), + } + } + self.0 + } +} + +pub fn remove_nonlocal_edges( + hugr: &mut impl HugrMut, + root: Node, +) -> Result<(), NonLocalEdgesError> { + let nonlocal_edges_map: HashMap = + nonlocal_edges(&DescendantsGraph::::try_new(hugr, root)?) + .map(|target @ (node, inport)| { + let source = { + let (n, p) = hugr.single_linked_output(node, inport).unwrap(); + Wire::new(n, p) + }; + debug_assert!( + hugr.get_parent(source.node()).unwrap() != hugr.get_parent(node).unwrap() + ); + let Some(EdgeKind::Value(ty)) = + hugr.get_optype(source.node()).port_kind(source.source()) + else { + panic!("impossible") + }; + (node, WorkItem { source, target, ty }) + }) + .collect(); + + if nonlocal_edges_map.is_empty() { + return Ok(()); + } + + let bb_needs_sources_map = { + let nonlocal_sorted = { + let mut v = iter::successors(Some(vec![root]), |nodes| { + let children = nodes + .iter() + .flat_map(|&n| hugr.children(n)) + .collect_vec(); + (!children.is_empty()).then_some(children) + }) + .flatten() + .filter_map(|n| nonlocal_edges_map.get(&n)) + .collect_vec(); + v.reverse(); + v + }; + let mut parent_set = HashSet::::new(); + // earlier items are deeper in the heirarchy + let mut parent_worklist = VecDeque::::new(); + let mut add_parent = |p, wl: &mut VecDeque<_>| { + if parent_set.insert(p) { + wl.push_back(p); + } + }; + let mut bnsm = BBNeedsSourcesMapBuilder::default(); + for workitem in nonlocal_sorted { + let parent = hugr.get_parent(workitem.target.0).unwrap(); + debug_assert!(hugr.get_parent(parent).is_some()); + bnsm.insert(parent, workitem.source, workitem.ty.clone()); + add_parent(parent, &mut parent_worklist); + } + + while let Some(bb_node) = parent_worklist.pop_front() { + let Some(parent) = hugr.get_parent(bb_node) else { + continue; }; - non_local_edges.push(workitem); - wire + if bnsm.extend_parent_needs_for(&hugr, bb_node) { + add_parent(parent, &mut parent_worklist); + } + } + bnsm.finish(&hugr) + }; + + #[cfg(debug_assertions)] + { + for (&n, wi) in nonlocal_edges_map.iter() { + let mut m = n; + loop { + let parent = hugr.get_parent(m).unwrap(); + if hugr.get_parent(wi.source.node()).unwrap() == parent { + break; + } + assert!(bb_needs_sources_map[&parent].contains_key(&wi.source)); + m = parent; + } + } + + for &bb in bb_needs_sources_map.keys() { + assert!(hugr.get_parent(bb).is_some()); + } + } + + let mut worklist = nonlocal_edges_map.into_values().collect_vec(); + let mut parent_source_map = ParentSourceMap::default(); + + for (bb, needs_sources) in bb_needs_sources_map { + worklist.extend(thread_sources( + &mut parent_source_map, + hugr, + bb, + needs_sources, + )); + } + + let parent_source_map = parent_source_map; + + while let Some(wi) = worklist.pop() { + let parent = hugr.get_parent(wi.target.0).unwrap(); + let source = if hugr.get_parent(wi.source.node()).unwrap() == parent { + wi.source + } else { + parent_source_map.get(parent, wi.source).unwrap() }; - hugr.disconnect(target.0, target.1); - hugr.connect( - local_source.node(), - local_source.source(), - target.0, - target.1, - ); + debug_assert_eq!(hugr.get_parent(source.node()), hugr.get_parent(wi.target.0)); + hugr.disconnect(wi.target.0, wi.target.1); + hugr.connect(source.node(), source.source(), wi.target.0, wi.target.1); } Ok(()) @@ -337,7 +550,8 @@ mod test { outer.finish_hugr_with_outputs([w1]).unwrap() }; assert!(ensure_no_nonlocal_edges(&hugr).is_err()); - remove_nonlocal_edges(&mut hugr).unwrap(); + let root = hugr.root(); + remove_nonlocal_edges(&mut hugr, root).unwrap(); hugr.validate().unwrap(); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); } @@ -380,7 +594,8 @@ mod test { outer.finish_hugr_with_outputs([s1, s2, s3]).unwrap() }; assert!(ensure_no_nonlocal_edges(&hugr).is_err()); - remove_nonlocal_edges(&mut hugr).unwrap(); + let root = hugr.root(); + remove_nonlocal_edges(&mut hugr, root).unwrap(); hugr.validate().unwrap_or_else(|e| panic!("{e}")); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); } @@ -391,31 +606,40 @@ mod test { let out_variants = vec![t1.clone().into(), t2.clone().into()]; let out_type = Type::new_sum(out_variants.clone()); let mut hugr = { - let mut outer = DFGBuilder::new(Signature::new(vec![ - t1.clone(), - t2.clone(), - t3.clone() - ], out_type.clone())) + let mut outer = DFGBuilder::new(Signature::new( + vec![t1.clone(), t2.clone(), t3.clone()], + out_type.clone(), + )) .unwrap(); let [s1, s2, s3] = outer.input_wires_arr(); let [out] = { let mut cond = outer - .conditional_builder((vec![type_row![];3], s3), [], out_type.into()).unwrap(); + .conditional_builder((vec![type_row![]; 3], s3), [], out_type.into()) + .unwrap(); { let mut case = cond.case_builder(0).unwrap(); - let [r] = case.add_dataflow_op(Tag::new(0, out_variants.clone()), [s1]).unwrap().outputs_arr(); + let [r] = case + .add_dataflow_op(Tag::new(0, out_variants.clone()), [s1]) + .unwrap() + .outputs_arr(); case.finish_with_outputs([r]).unwrap(); } { let mut case = cond.case_builder(1).unwrap(); - let [r] = case.add_dataflow_op(Tag::new(1, out_variants.clone()), [s2]).unwrap().outputs_arr(); + let [r] = case + .add_dataflow_op(Tag::new(1, out_variants.clone()), [s2]) + .unwrap() + .outputs_arr(); case.finish_with_outputs([r]).unwrap(); } { let mut case = cond.case_builder(2).unwrap(); let u = case.add_load_value(Value::unit()); - let [r] = case.add_dataflow_op(Tag::new(0, out_variants.clone()), [u]).unwrap().outputs_arr(); + let [r] = case + .add_dataflow_op(Tag::new(0, out_variants.clone()), [u]) + .unwrap() + .outputs_arr(); case.finish_with_outputs([r]).unwrap(); } cond.finish_sub_container().unwrap().outputs_arr() @@ -423,7 +647,8 @@ mod test { outer.finish_hugr_with_outputs([out]).unwrap() }; assert!(ensure_no_nonlocal_edges(&hugr).is_err()); - remove_nonlocal_edges(&mut hugr).unwrap(); + let root = hugr.root(); + remove_nonlocal_edges(&mut hugr, root).unwrap(); hugr.validate().unwrap_or_else(|e| panic!("{e}")); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); } From 5f09ba96fbd0b199c5c368625a4605124296978a Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Fri, 7 Feb 2025 17:18:31 +0000 Subject: [PATCH 07/61] works --- hugr-passes/Cargo.toml | 1 + hugr-passes/src/non_local.rs | 654 +++++++++++++++++++++++------------ 2 files changed, 441 insertions(+), 214 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 39162a52e6..fe7db4be5a 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -25,6 +25,7 @@ paste = { workspace = true } thiserror = { workspace = true } petgraph = { workspace = true } derive_more = { workspace = true, features = ["from", "error", "display"] } +delegate.workspace = true [features] extension_inference = ["hugr-core/extension_inference"] diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 26584d5a34..22bef220e8 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,8 +1,9 @@ //! This module provides functions for inspecting and modifying the nature of //! non local edges in a Hugr. +use delegate::delegate; use std::{ collections::{BTreeMap, HashMap, HashSet, VecDeque}, - iter, + iter, mem, }; //TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions @@ -17,7 +18,7 @@ use hugr_core::{ }, ops::{DataflowOpTrait as _, OpType, Tag, TailLoop}, types::{EdgeKind, Type, TypeRow}, - HugrView, IncomingPort, Node, Wire, + HugrView, IncomingPort, Node, PortIndex, Wire, }; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -93,220 +94,375 @@ struct WorkItem { ty: Type, } -fn thread_dataflow_parent( - hugr: &mut impl HugrMut, - parent: Node, - start_port_index: usize, - types: Vec, -) -> impl Iterator { - let [input_n, _] = hugr.get_io(parent).unwrap(); - let OpType::Input(mut input) = hugr.get_optype(input_n).clone() else { - panic!("impossible") - }; - let mut r = vec![]; - for (i, ty) in types.into_iter().enumerate() { - input - .types - .to_mut() - .insert(start_port_index + i, ty.clone()); - r.push(Wire::new( - input_n, - hugr.insert_outgoing_port(input_n, start_port_index + i), - )); +#[derive(Clone, Default, Debug)] +struct ParentSourceMap(HashMap>); + +impl ParentSourceMap { + // fn contains_parent(&self, parent: Node) -> bool { + // self.0.contains_key(&parent) + // } + + fn insert_sources_in_parent( + &mut self, + parent: Node, + sources: impl IntoIterator, + ) { + debug_assert!(!self.0.contains_key(&parent)); + self.0.entry(parent).or_default().extend(sources); } - hugr.replace_op(input_n, input).unwrap(); - r.into_iter() -} -fn do_tailloop( - parent_source_map: &mut ParentSourceMap, - hugr: &mut impl HugrMut, - node: Node, - sources: impl IntoIterator, -) -> impl Iterator { - let (sources, types): (Vec<_>, Vec<_>) = sources.into_iter().unzip(); - let mut tailloop = hugr.get_optype(node).as_tail_loop().unwrap().clone(); - let start_port_index = tailloop.just_inputs.len(); - { - tailloop.just_inputs.to_mut().extend(types.iter().cloned()); - hugr.replace_op(node, tailloop).unwrap(); + fn get_source_in_parent(&self, parent: Node, source: Wire) -> Option { + self.0.get(&parent).and_then(|m| m.get(&source).cloned()) } - let tailloop_ports = (0..sources.len()) - .map(|i| hugr.insert_incoming_port(node, start_port_index + i)) - .collect_vec(); - - let input_wires = - thread_dataflow_parent(hugr, node, start_port_index, types.clone()).collect_vec(); - parent_source_map.insert( - node, - iter::zip(sources.iter().copied(), input_wires.iter().copied()), - ); - - let [_, o] = hugr.get_io(node).unwrap(); - let (cond, new_control_type) = { - let Some(EdgeKind::Value(control_type)) = - hugr.get_optype(o).port_kind(IncomingPort::from(0)) - else { - panic!("impossible") - }; - let Some(sum_type) = control_type.as_sum_type() else { + + fn thread_dataflow_parent( + &mut self, + hugr: &mut impl HugrMut, + parent: Node, + start_port_index: usize, + sources: impl IntoIterator, + ) -> impl Iterator { + let [input_n, _] = hugr.get_io(parent).unwrap(); + let OpType::Input(mut input) = hugr.get_optype(input_n).clone() else { panic!("impossible") }; + let mut input_wires = vec![]; + self.0 + .entry(parent) + .or_default() + .extend(sources.into_iter().enumerate().map(|(i, (source, ty))| { + input.types.to_mut().insert(start_port_index + i, ty); + let input_wire = Wire::new( + input_n, + hugr.insert_outgoing_port(input_n, start_port_index + i), + ); + input_wires.push(input_wire); + (source, input_wire) + })); + hugr.replace_op(input_n, input).unwrap(); + input_wires.into_iter() + } +} - let old_sum_rows: Vec = sum_type - .iter_variants() - .map(|x| x.clone().try_into().unwrap()) - .collect_vec(); - let new_sum_rows = { - let mut v = old_sum_rows.clone(); - v[TailLoop::CONTINUE_TAG] - .to_mut() - .extend(types.iter().cloned()); - v - }; +#[derive(Clone, Debug)] +struct ThreadState<'a> { + parent_source_map: ParentSourceMap, + needs: &'a BBNeedsSourcesMap, + worklist: Vec, +} - let new_control_type = Type::new_sum(new_sum_rows.clone()); - let mut cond = - ConditionalBuilder::new(old_sum_rows, types.clone(), new_control_type.clone()).unwrap(); - for i in 0..2 { - let mut case = cond.case_builder(i).unwrap(); - let inputs = { - let all_inputs = case.input_wires(); - if i == TailLoop::CONTINUE_TAG { - Either::Left(all_inputs) +impl<'a> ThreadState<'a> { + delegate! { + to self.parent_source_map { + // fn contains_parent(&self, parent: Node) -> bool; + fn get_source_in_parent(&self, parent: Node, source: Wire) -> Option; + fn insert_sources_in_parent(&mut self, parent: Node, sources: impl IntoIterator); + fn thread_dataflow_parent( + &mut self, + hugr: &mut impl HugrMut, + parent: Node, + start_port_index: usize, + sources: impl IntoIterator, + ) -> impl Iterator; + } + } + + fn new(bbnsm: &'a BBNeedsSourcesMap) -> Self { + Self { + parent_source_map: Default::default(), + needs: bbnsm, + worklist: vec![], + } + } + + fn do_dataflow_block( + &mut self, + hugr: &mut impl HugrMut, + node: Node, + sources: Vec<(Wire, Type)>, + ) { + let types = sources.iter().map(|x| x.1.clone()).collect_vec(); + let new_sum_row_prefixes = { + let mut dfb = hugr.get_optype(node).as_dataflow_block().unwrap().clone(); + let mut nsrp = vec![vec![]; dfb.sum_rows.len()]; + dfb.inputs.to_mut().extend(types.clone()); + for (this_p, succ_n) in hugr.node_outputs(node).filter_map(|out_p| { + let (succ_n, _) = hugr.single_linked_input(node, out_p).unwrap(); + if hugr.get_optype(succ_n).is_exit_block() { + None } else { - Either::Right(all_inputs.into_iter().dropping_back(types.len())) + Some((out_p.index(), succ_n)) } + }) { + let succ_needs = &self.needs[&succ_n]; + let new_tys = succ_needs + .iter() + .map(|(&w, ty)| { + ( + sources.iter().find_position(|(x, _)| x == &w).unwrap().0, + ty.clone(), + ) + }) + .collect_vec(); + nsrp[this_p] = new_tys.clone(); + let tys = dfb.sum_rows[this_p].to_mut(); + let old_tys = mem::replace(tys, new_tys.into_iter().map(|x| x.1).collect_vec()); + tys.extend(old_tys); + } + hugr.replace_op(node, dfb).unwrap(); + nsrp + }; + + let input_wires = self + .thread_dataflow_parent(hugr, node, 0, sources.clone()) + .collect_vec(); + + let [_, o] = hugr.get_io(node).unwrap(); + let (cond, new_control_type) = { + let Some(EdgeKind::Value(control_type)) = + hugr.get_optype(o).port_kind(IncomingPort::from(0)) + else { + panic!("impossible") + }; + let Some(sum_type) = control_type.as_sum_type() else { + panic!("impossible") }; - let case_outputs = case - .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), inputs) - .unwrap() - .outputs(); - case.finish_with_outputs(case_outputs).unwrap(); + let old_sum_rows: Vec = sum_type + .iter_variants() + .map(|x| x.clone().try_into().unwrap()) + .collect_vec(); + let new_sum_rows: Vec = + itertools::zip_eq(new_sum_row_prefixes.clone(), old_sum_rows.iter()) + .map(|(new, old)| { + new.into_iter() + .map(|x| x.1) + .chain(old.iter().cloned()) + .collect_vec() + .into() + }) + .collect_vec(); + + let new_control_type = Type::new_sum(new_sum_rows.clone()); + let mut cond = ConditionalBuilder::new( + old_sum_rows.clone(), + types.clone(), + new_control_type.clone(), + ) + .unwrap(); + for (i, row) in new_sum_row_prefixes.iter().enumerate() { + let mut case = cond.case_builder(i).unwrap(); + let case_inputs = case.input_wires().collect_vec(); + let mut args = vec![]; + for (source_i, _) in row { + args.push(case_inputs[old_sum_rows[i].len() + source_i]); + } + + args.extend(&case_inputs[..old_sum_rows[i].len()]); + + let case_outputs = case + .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), args) + .unwrap() + .outputs(); + case.finish_with_outputs(case_outputs).unwrap(); + } + (cond.finish_hugr().unwrap(), new_control_type) + }; + let cond_node = hugr.insert_hugr(node, cond).new_root; + let (n, p) = hugr.single_linked_output(o, 0).unwrap(); + hugr.connect(n, p, cond_node, 0); + for (i, w) in input_wires.into_iter().enumerate() { + hugr.connect(w.node(), w.source(), cond_node, i + 1); } - (cond.finish_hugr().unwrap(), new_control_type) - }; - let cond_node = hugr.insert_hugr(node, cond).new_root; - let (n, p) = hugr.single_linked_output(o, 0).unwrap(); - hugr.connect(n, p, cond_node, 0); - for (i, w) in input_wires.into_iter().enumerate() { - hugr.connect(w.node(), w.source(), cond_node, i + 1); + hugr.disconnect(o, IncomingPort::from(0)); + hugr.connect(cond_node, 0, o, 0); + let mut output = hugr.get_optype(o).as_output().unwrap().clone(); + output.types.to_mut()[0] = new_control_type; + hugr.replace_op(o, output).unwrap(); } - hugr.disconnect(o, IncomingPort::from(0)); - hugr.connect(cond_node, 0, o, 0); - let mut output = hugr.get_optype(o).as_output().unwrap().clone(); - output.types.to_mut()[0] = new_control_type; - hugr.replace_op(o, output).unwrap(); - mk_workitems(node, sources, tailloop_ports, types) -} -#[derive(Clone, Default, Debug)] -struct ParentSourceMap(HashMap>); + fn do_cfg(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { + let types = sources.iter().map(|x| x.1.clone()).collect_vec(); + { + let mut cfg = hugr.get_optype(node).as_cfg().unwrap().clone(); + let inputs = cfg.signature.input.to_mut(); + let old_inputs = mem::replace(inputs, types); + inputs.extend(old_inputs); + hugr.replace_op(node, cfg).unwrap(); + } + let new_cond_ports = (0..sources.len()) + .map(|i| hugr.insert_incoming_port(node, i)) + .collect_vec(); + self.insert_sources_in_parent(node, iter::empty()); + self.worklist + .extend(mk_workitems(node, sources, new_cond_ports)) + } -impl ParentSourceMap { - fn contains_parent(&self, parent: Node) -> bool { - self.0.contains_key(&parent) + fn do_dfg(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { + let mut dfg = hugr.get_optype(node).as_dfg().unwrap().clone(); + let start_new_port_index = dfg.signature.input().len(); + let new_dfg_ports = (0..sources.len()) + .map(|i| hugr.insert_incoming_port(node, start_new_port_index + i)) + .collect_vec(); + dfg.signature + .input + .to_mut() + .extend(sources.iter().map(|x| x.1.clone())); + hugr.replace_op(node, dfg).unwrap(); + let _ = + self.thread_dataflow_parent(hugr, node, start_new_port_index, sources.iter().cloned()); + self.worklist + .extend(mk_workitems(node, sources, new_dfg_ports)); } - fn insert(&mut self, parent: Node, sources: impl IntoIterator) { - debug_assert!(!self.0.contains_key(&parent)); - self.0.entry(parent).or_default().extend(sources); + fn do_conditional(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { + let mut cond = hugr.get_optype(node).as_conditional().unwrap().clone(); + let start_new_port_index = cond.signature().input().len(); + cond.other_inputs + .to_mut() + .extend(sources.iter().map(|x| x.1.clone())); + hugr.replace_op(node, cond).unwrap(); + let new_cond_ports = (0..sources.len()) + .map(|i| hugr.insert_incoming_port(node, start_new_port_index + i)) + .collect_vec(); + self.insert_sources_in_parent(node, iter::empty()); + self.worklist + .extend(mk_workitems(node, sources, new_cond_ports)) } - fn get(&self, parent: Node, source: Wire) -> Option { - self.0.get(&parent).and_then(|m| m.get(&source).cloned()) + fn do_case(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { + let mut case = hugr.get_optype(node).as_case().unwrap().clone(); + let start_case_port_index = case.signature.input().len(); + case.signature + .input + .to_mut() + .extend(sources.iter().map(|x| x.1.clone())); + hugr.replace_op(node, case).unwrap(); + let _ = self.thread_dataflow_parent(hugr, node, start_case_port_index, sources); + } + + fn do_tailloop(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { + let mut tailloop = hugr.get_optype(node).as_tail_loop().unwrap().clone(); + let types = sources.iter().map(|x| x.1.clone()).collect_vec(); + let start_port_index = tailloop.just_inputs.len(); + { + tailloop.just_inputs.to_mut().extend(types.clone()); + hugr.replace_op(node, tailloop).unwrap(); + } + let tailloop_ports = (0..sources.len()) + .map(|i| hugr.insert_incoming_port(node, start_port_index + i)) + .collect_vec(); + + let input_wires = self + .thread_dataflow_parent(hugr, node, start_port_index, sources.clone()) + .collect_vec(); + + let [_, o] = hugr.get_io(node).unwrap(); + let (cond, new_control_type) = { + let Some(EdgeKind::Value(control_type)) = + hugr.get_optype(o).port_kind(IncomingPort::from(0)) + else { + panic!("impossible") + }; + let Some(sum_type) = control_type.as_sum_type() else { + panic!("impossible") + }; + + let old_sum_rows: Vec = sum_type + .iter_variants() + .map(|x| x.clone().try_into().unwrap()) + .collect_vec(); + let new_sum_rows = { + let mut v = old_sum_rows.clone(); + v[TailLoop::CONTINUE_TAG] + .to_mut() + .extend(types.iter().cloned()); + v + }; + + let new_control_type = Type::new_sum(new_sum_rows.clone()); + let mut cond = + ConditionalBuilder::new(old_sum_rows, types.clone(), new_control_type.clone()) + .unwrap(); + for i in 0..2 { + let mut case = cond.case_builder(i).unwrap(); + let inputs = { + let all_inputs = case.input_wires(); + if i == TailLoop::CONTINUE_TAG { + Either::Left(all_inputs) + } else { + Either::Right(all_inputs.into_iter().dropping_back(types.len())) + } + }; + + let case_outputs = case + .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), inputs) + .unwrap() + .outputs(); + case.finish_with_outputs(case_outputs).unwrap(); + } + (cond.finish_hugr().unwrap(), new_control_type) + }; + let cond_node = hugr.insert_hugr(node, cond).new_root; + let (n, p) = hugr.single_linked_output(o, 0).unwrap(); + hugr.connect(n, p, cond_node, 0); + for (i, w) in input_wires.into_iter().enumerate() { + hugr.connect(w.node(), w.source(), cond_node, i + 1); + } + hugr.disconnect(o, IncomingPort::from(0)); + hugr.connect(cond_node, 0, o, 0); + let mut output = hugr.get_optype(o).as_output().unwrap().clone(); + output.types.to_mut()[0] = new_control_type; + hugr.replace_op(o, output).unwrap(); + self.worklist + .extend(mk_workitems(node, sources, tailloop_ports)) + } + + fn finish(self, _hugr: &mut impl HugrMut) -> (Vec, ParentSourceMap) { + (self.worklist, self.parent_source_map) } } +fn thread_sources( + hugr: &mut impl HugrMut, + bb_needs_sources_map: &BBNeedsSourcesMap, +) -> (Vec, ParentSourceMap) { + let mut state = ThreadState::new(bb_needs_sources_map); + for (&bb, sources) in bb_needs_sources_map { + let sources = sources + .iter() + .map(|(&w, ty)| (w, ty.clone())) + .collect_vec(); + match hugr.get_optype(bb).clone() { + OpType::DFG(_) => state.do_dfg(hugr, bb, sources), + OpType::Conditional(_) => state.do_conditional(hugr, bb, sources), + OpType::Case(_) => state.do_case(hugr, bb, sources), + OpType::TailLoop(_) => state.do_tailloop(hugr, bb, sources), + OpType::DataflowBlock(_) => state.do_dataflow_block(hugr, bb, sources), + OpType::CFG(_) => state.do_cfg(hugr, bb, sources), + _ => panic!("impossible"), + } + } + + state.finish(hugr) +} + fn mk_workitems( node: Node, - sources: impl IntoIterator, + sources: impl IntoIterator, ports: impl IntoIterator, - types: impl IntoIterator, ) -> impl Iterator { - itertools::izip!(sources, ports, types).map(move |(source, p, ty)| WorkItem { + itertools::izip!(sources, ports).map(move |((source, ty), p)| WorkItem { source, target: (node, p), ty, }) } -fn thread_sources( - parent_source_map: &mut ParentSourceMap, - hugr: &mut impl HugrMut, - bb: Node, - sources: impl IntoIterator, -) -> Vec { - let (source_wires, types): (Vec<_>, Vec<_>) = sources.into_iter().unzip(); - match hugr.get_optype(bb).clone() { - OpType::DFG(mut dfg) => { - debug_assert!(!parent_source_map.contains_parent(bb)); - let start_new_port_index = dfg.signature.input().len(); - let new_dfg_ports = (0..source_wires.len()) - .map(|i| hugr.insert_incoming_port(bb, start_new_port_index + i)) - .collect_vec(); - dfg.signature.input.to_mut().extend(types.clone()); - hugr.replace_op(bb, dfg).unwrap(); - for (source, &target) in iter::zip(source_wires.iter(), new_dfg_ports.iter()) { - hugr.connect(source.node(), source.source(), bb, target); - } - parent_source_map.insert( - bb, - iter::zip( - source_wires.iter().copied(), - thread_dataflow_parent(hugr, bb, start_new_port_index, types.clone()), - ), - ); - mk_workitems(bb, source_wires, new_dfg_ports, types).collect_vec() - } - OpType::Conditional(mut cond) => { - debug_assert!(!parent_source_map.contains_parent(bb)); - let start_new_port_index = cond.signature().input().len(); - cond.other_inputs.to_mut().extend(types.clone()); - hugr.replace_op(bb, cond).unwrap(); - let new_cond_ports = (0..source_wires.len()) - .map(|i| hugr.insert_incoming_port(bb, start_new_port_index + i)) - .collect_vec(); - parent_source_map.insert(bb, iter::empty()); - mk_workitems(bb, source_wires, new_cond_ports, types).collect_vec() - } - OpType::Case(mut case) => { - debug_assert!(!parent_source_map.contains_parent(bb)); - let start_case_port_index = case.signature.input().len(); - case.signature.input.to_mut().extend(types.clone()); - hugr.replace_op(bb, case).unwrap(); - parent_source_map.insert( - bb, - iter::zip( - source_wires.iter().copied(), - thread_dataflow_parent(hugr, bb, start_case_port_index, types), - ), - ); - vec![] - } - OpType::TailLoop(_) => { - do_tailloop(parent_source_map, hugr, bb, iter::zip(source_wires, types)).collect_vec() - } - _ => panic!("impossible"), - } - // _ => panic!("impossible"), - // }; - // non_local_edges.push(workitem); - // wire - // }; - // hugr.disconnect(target.0, target.1); - // hugr.connect( - // local_source.node(), - // local_source.source(), - // target.0, - // target.1, - // ); - // } -} +type BBNeedsSourcesMap = HashMap>; #[derive(Debug, Default, Clone)] -struct BBNeedsSourcesMapBuilder(HashMap>); +struct BBNeedsSourcesMapBuilder(BBNeedsSourcesMap); impl BBNeedsSourcesMapBuilder { fn insert(&mut self, bb: Node, source: Wire, ty: Type) { @@ -320,7 +476,8 @@ impl BBNeedsSourcesMapBuilder { .get(&child) .into_iter() .flat_map(move |m| { - m.iter().filter(move |(w, _)| hugr.get_parent(w.node()).unwrap() != parent) + m.iter() + .filter(move |(w, _)| hugr.get_parent(w.node()).unwrap() != parent) .map(|(&w, ty)| (w, ty.clone())) }) .collect_vec(); @@ -331,15 +488,15 @@ impl BBNeedsSourcesMapBuilder { any } - fn finish(mut self, hugr: impl HugrView) -> HashMap> { - let conds = self - .0 - .keys() - .copied() - .filter(|&n| hugr.get_optype(n).is_conditional()) - .collect_vec(); - for cond in conds { - if hugr.get_optype(cond).is_conditional() { + fn finish(mut self, hugr: impl HugrView) -> BBNeedsSourcesMap { + { + let conds = self + .0 + .keys() + .copied() + .filter(|&n| hugr.get_optype(n).is_conditional()) + .collect_vec(); + for cond in conds { let cases = hugr .children(cond) .filter(|&child| hugr.get_optype(child).is_case()) @@ -358,6 +515,40 @@ impl BBNeedsSourcesMapBuilder { } } } + { + let cfgs = self + .0 + .keys() + .copied() + .filter(|&n| hugr.get_optype(n).is_cfg() && self.0.contains_key(&n)) + .collect_vec(); + for cfg in cfgs { + let dfbs = hugr + .children(cfg) + .filter(|&child| hugr.get_optype(child).is_dataflow_block()) + .collect_vec(); + + // let mut dfb_needs_map: HashMap<_, _> = dfbs + // .iter() + // .map(|&n| (n, self.0.get(&n).cloned().unwrap_or_default())) + // .collect(); + loop { + let mut any_change = false; + for &dfb in dfbs.iter() { + for succ_n in hugr.output_neighbours(dfb) { + for (w, ty) in self.0.get(&succ_n).cloned().unwrap_or_default() { + any_change |= + self.0.entry(dfb).or_default().insert(w, ty).is_none(); + } + } + } + if !any_change { + break; + } + } + } + } + self.0 } } @@ -392,10 +583,7 @@ pub fn remove_nonlocal_edges( let bb_needs_sources_map = { let nonlocal_sorted = { let mut v = iter::successors(Some(vec![root]), |nodes| { - let children = nodes - .iter() - .flat_map(|&n| hugr.children(n)) - .collect_vec(); + let children = nodes.iter().flat_map(|&n| hugr.children(n)).collect_vec(); (!children.is_empty()).then_some(children) }) .flatten() @@ -450,26 +638,21 @@ pub fn remove_nonlocal_edges( } } - let mut worklist = nonlocal_edges_map.into_values().collect_vec(); - let mut parent_source_map = ParentSourceMap::default(); - - for (bb, needs_sources) in bb_needs_sources_map { - worklist.extend(thread_sources( - &mut parent_source_map, - hugr, - bb, - needs_sources, - )); - } - - let parent_source_map = parent_source_map; + let (parent_source_map, worklist) = { + let mut worklist = nonlocal_edges_map.into_values().collect_vec(); + let (wl, psm) = thread_sources(hugr, &bb_needs_sources_map); + worklist.extend(wl); + (psm, worklist) + }; - while let Some(wi) = worklist.pop() { + for wi in worklist { let parent = hugr.get_parent(wi.target.0).unwrap(); let source = if hugr.get_parent(wi.source.node()).unwrap() == parent { wi.source } else { - parent_source_map.get(parent, wi.source).unwrap() + parent_source_map + .get_source_in_parent(parent, wi.source) + .unwrap() }; debug_assert_eq!(hugr.get_parent(source.node()), hugr.get_parent(wi.target.0)); hugr.disconnect(wi.target.0, wi.target.1); @@ -537,7 +720,7 @@ mod test { } #[test] - fn remove_nonlocal_edges_dfg() { + fn dfg() { let mut hugr = { let mut outer = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let [w0] = outer.input_wires_arr(); @@ -557,7 +740,7 @@ mod test { } #[test] - fn remove_nonlocal_edges_tailloop() { + fn tailloop() { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); let mut hugr = { let mut outer = DFGBuilder::new(Signature::new_endo(vec![ @@ -601,7 +784,7 @@ mod test { } #[test] - fn remove_nonlocal_edges_cond() { + fn conditional() { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); let out_variants = vec![t1.clone().into(), t2.clone().into()]; let out_type = Type::new_sum(out_variants.clone()); @@ -652,4 +835,47 @@ mod test { hugr.validate().unwrap_or_else(|e| panic!("{e}")); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); } + + #[test] + fn cfg() { + let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); + // let out_variants = vec![t1.clone().into(), t2.clone().into()]; + let out_type = t1.clone(); + let mut hugr = { + let mut outer = DFGBuilder::new(Signature::new( + vec![t1.clone(), t2.clone(), t3.clone()], + out_type.clone(), + )) + .unwrap(); + let [s1, s2, s3] = outer.input_wires_arr(); + let [out] = { + let mut cfg = outer.cfg_builder([], out_type.into()).unwrap(); + + let entry = { + let mut entry = cfg.entry_builder([type_row![]], type_row![]).unwrap(); + let w = entry.add_load_value(Value::unit()); + entry.finish_with_outputs(w, []).unwrap() + }; + let exit = cfg.exit_block(); + + let bb1 = { + let mut entry = cfg + .block_builder(type_row![], [type_row![]], t1.clone().into()) + .unwrap(); + let w = entry.add_load_value(Value::unit()); + entry.finish_with_outputs(w, [s1]).unwrap() + }; + cfg.branch(&entry, 0, &bb1).unwrap(); + cfg.branch(&bb1, 0, &exit).unwrap(); + cfg.finish_sub_container().unwrap().outputs_arr() + }; + outer.finish_hugr_with_outputs([out]).unwrap() + }; + assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + let root = hugr.root(); + remove_nonlocal_edges(&mut hugr, root).unwrap(); + println!("{}", hugr.mermaid_string()); + hugr.validate().unwrap_or_else(|e| panic!("{e}")); + assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + } } From 8f0b164b693e62448573e8467c100f15f6de1fae Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sat, 8 Feb 2025 14:19:48 +0000 Subject: [PATCH 08/61] wip --- hugr-passes/src/non_local.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 22bef220e8..09feb42c14 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -720,7 +720,7 @@ mod test { } #[test] - fn dfg() { + fn unnonlocal_dfg() { let mut hugr = { let mut outer = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let [w0] = outer.input_wires_arr(); @@ -740,7 +740,7 @@ mod test { } #[test] - fn tailloop() { + fn unnonlocal_tailloop() { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); let mut hugr = { let mut outer = DFGBuilder::new(Signature::new_endo(vec![ @@ -784,7 +784,7 @@ mod test { } #[test] - fn conditional() { + fn unnonlocal_conditional() { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); let out_variants = vec![t1.clone().into(), t2.clone().into()]; let out_type = Type::new_sum(out_variants.clone()); @@ -837,7 +837,7 @@ mod test { } #[test] - fn cfg() { + fn unnonlocal_cfg() { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); // let out_variants = vec![t1.clone().into(), t2.clone().into()]; let out_type = t1.clone(); From af2959a77b4cbeb4d0042b94dbab56640b6a3edb Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sat, 8 Feb 2025 14:26:39 +0000 Subject: [PATCH 09/61] feat: Add `Type::as_sum` and `SumType::variants`. --- .../std_extensions/arithmetic/float_types.rs | 1 + .../std_extensions/arithmetic/int_types.rs | 1 + .../src/std_extensions/collections/array.rs | 1 + hugr-core/src/types.rs | 31 +++++++++++++++---- 4 files changed, 28 insertions(+), 6 deletions(-) diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index 304b940453..579d89e6bb 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -65,6 +65,7 @@ impl std::ops::Deref for ConstF64 { impl ConstF64 { /// Name of the constructor for creating constant 64bit floats. + #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "arithmetic.float.const-f64"; /// Create a new [`ConstF64`] diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 1342dd9320..e5d625695e 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -105,6 +105,7 @@ pub struct ConstInt { impl ConstInt { /// Name of the constructor for creating constant integers. + #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "arithmetic.int.const"; /// Create a new [`ConstInt`] with a given width and unsigned value diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index 618bd61826..93f58727cb 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -43,6 +43,7 @@ pub struct ArrayValue { impl ArrayValue { /// Name of the constructor for creating constant arrays. + #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "collections.array.const"; /// Create a new [CustomConst] for an array of values of type `typ`. diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 962e00876c..c22c1fff8a 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -27,7 +27,7 @@ pub use type_row::{TypeRow, TypeRowRV}; pub(crate) use poly_func::PolyFuncTypeBase; use itertools::FoldWhile::{Continue, Done}; -use itertools::{repeat_n, Itertools}; +use itertools::{Either, Itertools as _}; #[cfg(test)] use proptest_derive::Arbitrary; use serde::{Deserialize, Serialize}; @@ -189,7 +189,7 @@ impl std::fmt::Display for SumType { SumType::Unit { size: 1 } => write!(f, "Unit"), SumType::Unit { size: 2 } => write!(f, "Bool"), SumType::Unit { size } => { - display_list_with_separator(repeat_n("[]", *size as usize), f, "+") + display_list_with_separator(itertools::repeat_n("[]", *size as usize), f, "+") } SumType::General { rows } => match rows.len() { 1 if rows[0].is_empty() => write!(f, "Unit"), @@ -216,17 +216,17 @@ impl SumType { } } - /// New UnitSum with empty Tuple variants + /// New UnitSum with empty Tuple variants. pub const fn new_unary(size: u8) -> Self { Self::Unit { size } } - /// New tuple (single row of variants) + /// New tuple (single row of variants). pub fn new_tuple(types: impl Into) -> Self { Self::new([types.into()]) } - /// New option type (either an empty option, or a row of types) + /// New option type (either an empty option, or a row of types). pub fn new_option(types: impl Into) -> Self { Self::new([vec![].into(), types.into()]) } @@ -248,7 +248,7 @@ impl SumType { } } - /// Returns variant row if there is only one variant + /// Returns variant row if there is only one variant. pub fn as_tuple(&self) -> Option<&TypeRowRV> { match self { SumType::Unit { size } if *size == 1 => Some(TypeRV::EMPTY_TYPEROW_REF), @@ -256,6 +256,17 @@ impl SumType { _ => None, } } + + /// Returns an iterator over the variants. + pub fn variants(&self) -> impl Iterator { + match self { + SumType::Unit { size } => Either::Left(itertools::repeat_n( + TypeRV::EMPTY_TYPEROW_REF, + *size as usize, + )), + SumType::General { rows } => Either::Right(rows.iter()), + } + } } impl From for TypeBase { @@ -453,6 +464,14 @@ impl TypeBase { &mut self.0 } + /// Returns the inner [SumType] if the type is a sum. + pub fn as_sum(&self) -> Option<&SumType> { + match &self.0 { + TypeEnum::Sum(s) => Some(s), + _ => None, + } + } + /// Report if the type is copyable - i.e.the least upper bound of the type /// is contained by the copyable bound. pub const fn copyable(&self) -> bool { From 2e3e282ef8aff7c5c0cd76c34b2b187839622a2a Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sun, 9 Feb 2025 09:04:42 +0000 Subject: [PATCH 10/61] coverage --- hugr-core/src/types.rs | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index c22c1fff8a..fa10e2a609 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -732,13 +732,37 @@ pub(crate) mod test { assert_eq!(pred1, Type::from(pred_direct)); } + #[test] + fn as_sum() { + let t = Type::new_unit_sum(0); + assert!(t.as_sum().is_some()); + } + + #[test] + fn sum_variants() { + { + let variants: Vec = vec![ + TypeRV::UNIT.into(), + vec![TypeRV::new_row_var_use(0, TypeBound::Any)].into(), + ]; + let t = SumType::new(variants.clone()); + assert_eq!(variants, t.variants().cloned().collect_vec()); + } + { + assert_eq!( + vec![&TypeRV::EMPTY_TYPEROW;3], + SumType::new_unary(3).variants().collect_vec() + ); + } + } + mod proptest { use crate::proptest::RecursionDepth; use super::{AliasDecl, MaybeRV, TypeBase, TypeBound, TypeEnum}; use crate::types::{CustomType, FuncValueType, SumType, TypeRowRV}; - use ::proptest::prelude::*; + use proptest::prelude::*; impl Arbitrary for super::SumType { type Parameters = RecursionDepth; From dc95bc972666b3823ad3d134ac26891e32fd3ff8 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sun, 9 Feb 2025 09:48:27 +0000 Subject: [PATCH 11/61] wip --- hugr-passes/src/non_local.rs | 84 +++++++++++++++++++++++++++--------- 1 file changed, 63 insertions(+), 21 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 09feb42c14..71b127e89e 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -224,12 +224,12 @@ impl<'a> ThreadState<'a> { else { panic!("impossible") }; - let Some(sum_type) = control_type.as_sum_type() else { + let Some(sum_type) = control_type.as_sum() else { panic!("impossible") }; let old_sum_rows: Vec = sum_type - .iter_variants() + .variants() .map(|x| x.clone().try_into().unwrap()) .collect_vec(); let new_sum_rows: Vec = @@ -364,12 +364,12 @@ impl<'a> ThreadState<'a> { else { panic!("impossible") }; - let Some(sum_type) = control_type.as_sum_type() else { + let Some(sum_type) = control_type.as_sum() else { panic!("impossible") }; let old_sum_rows: Vec = sum_type - .iter_variants() + .variants() .map(|x| x.clone().try_into().unwrap()) .collect_vec(); let new_sum_rows = { @@ -665,8 +665,8 @@ pub fn remove_nonlocal_edges( #[cfg(test)] mod test { use hugr_core::{ - builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, - extension::prelude::{bool_t, Noop}, + builder::{Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, + extension::prelude::{bool_t, either_type, option_type, Noop}, ops::{handle::NodeHandle, Tag, TailLoop, Value}, type_row, types::Signature, @@ -841,35 +841,77 @@ mod test { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); // let out_variants = vec![t1.clone().into(), t2.clone().into()]; let out_type = t1.clone(); + // Cfg consists of 4 dataflow blocks and an exit block + // + // The 4 dataflow blocks form a diamond, and the bottom block branches + // either to the entry block or the exit block. + // + // Two non-local uses in the left block means that these values must + // be threaded through all blocks, because of the loop. + // + // All non-trivial(i.e. more than one choice of successor) branching is + // done on an option type to exercise both empty and occupied control + // sums. + // + // All branches have an other-output. let mut hugr = { + let branch_sum_type = either_type(Type::UNIT, Type::UNIT); + let branch_type = Type::from(branch_sum_type.clone()); + let branch_variants = branch_sum_type.variants().cloned().map(|x| x.try_into().unwrap()).collect_vec(); + let nonlocal1_type = bool_t(); + let nonlocal2_type = Type::new_unit_sum(3); + let other_output_type = branch_type.clone(); let mut outer = DFGBuilder::new(Signature::new( - vec![t1.clone(), t2.clone(), t3.clone()], - out_type.clone(), + vec![branch_type.clone(), nonlocal1_type.clone(), nonlocal2_type.clone(), Type::UNIT], + vec![Type::UNIT, other_output_type.clone()] )) .unwrap(); - let [s1, s2, s3] = outer.input_wires_arr(); - let [out] = { - let mut cfg = outer.cfg_builder([], out_type.into()).unwrap(); + let [b, nl1, nl2, unit] = outer.input_wires_arr(); + let [unit, out] = { + let mut cfg = outer.cfg_builder([(Type::UNIT, unit), (branch_type.clone(), b)], vec![Type::UNIT, other_output_type.clone()].into()).unwrap(); let entry = { - let mut entry = cfg.entry_builder([type_row![]], type_row![]).unwrap(); - let w = entry.add_load_value(Value::unit()); - entry.finish_with_outputs(w, []).unwrap() + let entry = cfg.entry_builder(branch_variants.clone(), other_output_type.clone().into()).unwrap(); + let [_, b] = entry.input_wires_arr(); + + entry.finish_with_outputs(b, [b]).unwrap() }; let exit = cfg.exit_block(); - let bb1 = { + let bb_left = { let mut entry = cfg - .block_builder(type_row![], [type_row![]], t1.clone().into()) + .block_builder(vec![Type::UNIT, other_output_type.clone()].into(), [type_row![]], other_output_type.clone().into()) + .unwrap(); + let [unit, oo] = entry.input_wires_arr(); + let [_] = entry.add_dataflow_op(Noop::new(nonlocal1_type), [nl1]).unwrap().outputs_arr(); + let [_] = entry.add_dataflow_op(Noop::new(nonlocal2_type), [nl2]).unwrap().outputs_arr(); + entry.finish_with_outputs(unit, [oo]).unwrap() + }; + + let bb_right = { + let entry = cfg + .block_builder(vec![Type::UNIT, other_output_type.clone()].into(), [type_row![]], other_output_type.clone().into()) + .unwrap(); + let [b, oo] = entry.input_wires_arr(); + entry.finish_with_outputs(unit, [oo]).unwrap() + }; + + let bb_bottom = { + let entry = cfg + .block_builder(branch_type.clone().into(), branch_variants, other_output_type.clone().into()) .unwrap(); - let w = entry.add_load_value(Value::unit()); - entry.finish_with_outputs(w, [s1]).unwrap() + let [oo] = entry.input_wires_arr(); + entry.finish_with_outputs(oo, [oo]).unwrap() }; - cfg.branch(&entry, 0, &bb1).unwrap(); - cfg.branch(&bb1, 0, &exit).unwrap(); + cfg.branch(&entry, 0, &bb_left).unwrap(); + cfg.branch(&entry, 1, &bb_right).unwrap(); + cfg.branch(&bb_left, 0, &bb_bottom).unwrap(); + cfg.branch(&bb_right, 0, &bb_bottom).unwrap(); + cfg.branch(&bb_bottom, 0, &entry).unwrap(); + cfg.branch(&bb_bottom, 1, &exit).unwrap(); cfg.finish_sub_container().unwrap().outputs_arr() }; - outer.finish_hugr_with_outputs([out]).unwrap() + outer.finish_hugr_with_outputs([unit, out]).unwrap() }; assert!(ensure_no_nonlocal_edges(&hugr).is_err()); let root = hugr.root(); From 9f619b47f5a145e0533439cc9dfbdae9b56314f5 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 09:37:33 +0000 Subject: [PATCH 12/61] wip --- hugr-core/src/types.rs | 2 +- hugr-passes/src/non_local.rs | 123 +++++++++++++++++++++-------------- 2 files changed, 76 insertions(+), 49 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index fa10e2a609..d69c7ef7d6 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -750,7 +750,7 @@ pub(crate) mod test { } { assert_eq!( - vec![&TypeRV::EMPTY_TYPEROW;3], + vec![&TypeRV::EMPTY_TYPEROW; 3], SumType::new_unary(3).variants().collect_vec() ); } diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 71b127e89e..8600f19c50 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -155,7 +155,7 @@ impl<'a> ThreadState<'a> { delegate! { to self.parent_source_map { // fn contains_parent(&self, parent: Node) -> bool; - fn get_source_in_parent(&self, parent: Node, source: Wire) -> Option; + // fn get_source_in_parent(&self, parent: Node, source: Wire) -> Option; fn insert_sources_in_parent(&mut self, parent: Node, sources: impl IntoIterator); fn thread_dataflow_parent( &mut self, @@ -183,33 +183,24 @@ impl<'a> ThreadState<'a> { ) { let types = sources.iter().map(|x| x.1.clone()).collect_vec(); let new_sum_row_prefixes = { - let mut dfb = hugr.get_optype(node).as_dataflow_block().unwrap().clone(); - let mut nsrp = vec![vec![]; dfb.sum_rows.len()]; - dfb.inputs.to_mut().extend(types.clone()); + let mut this_dfb = hugr.get_optype(node).as_dataflow_block().unwrap().clone(); + let mut nsrp = vec![vec![]; this_dfb.sum_rows.len()]; + vec_prepend(this_dfb.inputs.to_mut(), types.clone()); + for (this_p, succ_n) in hugr.node_outputs(node).filter_map(|out_p| { let (succ_n, _) = hugr.single_linked_input(node, out_p).unwrap(); - if hugr.get_optype(succ_n).is_exit_block() { - None - } else { - Some((out_p.index(), succ_n)) - } + hugr.get_optype(succ_n).is_dataflow_block().then_some((out_p.index(), succ_n)) }) { let succ_needs = &self.needs[&succ_n]; - let new_tys = succ_needs + let succ_needs_source_indices = succ_needs .iter() - .map(|(&w, ty)| { - ( - sources.iter().find_position(|(x, _)| x == &w).unwrap().0, - ty.clone(), - ) - }) + .map(|(&w, _)| sources.iter().find_position(|(x, _)| x == &w).unwrap().0) .collect_vec(); - nsrp[this_p] = new_tys.clone(); - let tys = dfb.sum_rows[this_p].to_mut(); - let old_tys = mem::replace(tys, new_tys.into_iter().map(|x| x.1).collect_vec()); - tys.extend(old_tys); + let succ_needs_tys = succ_needs_source_indices.iter().copied().map(|x| sources[x].1.clone()).collect_vec(); + vec_prepend(this_dfb.sum_rows[this_p].to_mut(), succ_needs_tys); + nsrp[this_p] = succ_needs_source_indices; } - hugr.replace_op(node, dfb).unwrap(); + hugr.replace_op(node, this_dfb).unwrap(); nsrp }; @@ -233,11 +224,11 @@ impl<'a> ThreadState<'a> { .map(|x| x.clone().try_into().unwrap()) .collect_vec(); let new_sum_rows: Vec = - itertools::zip_eq(new_sum_row_prefixes.clone(), old_sum_rows.iter()) - .map(|(new, old)| { - new.into_iter() - .map(|x| x.1) - .chain(old.iter().cloned()) + itertools::zip_eq(new_sum_row_prefixes.iter(), old_sum_rows.iter()) + .map(|(new_source_indices, old_tys)| { + new_source_indices.into_iter() + .map(|&x| sources[x].1.clone()) + .chain(old_tys.iter().cloned()) .collect_vec() .into() }) @@ -250,11 +241,11 @@ impl<'a> ThreadState<'a> { new_control_type.clone(), ) .unwrap(); - for (i, row) in new_sum_row_prefixes.iter().enumerate() { + for (i, new_source_indices) in new_sum_row_prefixes.into_iter().enumerate() { let mut case = cond.case_builder(i).unwrap(); let case_inputs = case.input_wires().collect_vec(); let mut args = vec![]; - for (source_i, _) in row { + for source_i in new_source_indices { args.push(case_inputs[old_sum_rows[i].len() + source_i]); } @@ -279,6 +270,7 @@ impl<'a> ThreadState<'a> { let mut output = hugr.get_optype(o).as_output().unwrap().clone(); output.types.to_mut()[0] = new_control_type; hugr.replace_op(o, output).unwrap(); + dbg!(hugr.single_linked_output(o, 0)); } fn do_cfg(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { @@ -429,10 +421,7 @@ fn thread_sources( ) -> (Vec, ParentSourceMap) { let mut state = ThreadState::new(bb_needs_sources_map); for (&bb, sources) in bb_needs_sources_map { - let sources = sources - .iter() - .map(|(&w, ty)| (w, ty.clone())) - .collect_vec(); + let sources = sources.iter().map(|(&w, ty)| (w, ty.clone())).collect_vec(); match hugr.get_optype(bb).clone() { OpType::DFG(_) => state.do_dfg(hugr, bb, sources), OpType::Conditional(_) => state.do_conditional(hugr, bb, sources), @@ -662,11 +651,18 @@ pub fn remove_nonlocal_edges( Ok(()) } +fn vec_prepend(v: &mut Vec, ts: impl IntoIterator) { + let mut old_v = mem::replace(v, ts.into_iter().collect()); + v.extend(old_v.drain(..)); +} + #[cfg(test)] mod test { use hugr_core::{ - builder::{Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, - extension::prelude::{bool_t, either_type, option_type, Noop}, + builder::{ + DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer, + }, + extension::prelude::{bool_t, either_type, Noop}, ops::{handle::NodeHandle, Tag, TailLoop, Value}, type_row, types::Signature, @@ -838,9 +834,6 @@ mod test { #[test] fn unnonlocal_cfg() { - let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); - // let out_variants = vec![t1.clone().into(), t2.clone().into()]; - let out_type = t1.clone(); // Cfg consists of 4 dataflow blocks and an exit block // // The 4 dataflow blocks form a diamond, and the bottom block branches @@ -857,21 +850,37 @@ mod test { let mut hugr = { let branch_sum_type = either_type(Type::UNIT, Type::UNIT); let branch_type = Type::from(branch_sum_type.clone()); - let branch_variants = branch_sum_type.variants().cloned().map(|x| x.try_into().unwrap()).collect_vec(); + let branch_variants = branch_sum_type + .variants() + .cloned() + .map(|x| x.try_into().unwrap()) + .collect_vec(); let nonlocal1_type = bool_t(); let nonlocal2_type = Type::new_unit_sum(3); let other_output_type = branch_type.clone(); let mut outer = DFGBuilder::new(Signature::new( - vec![branch_type.clone(), nonlocal1_type.clone(), nonlocal2_type.clone(), Type::UNIT], - vec![Type::UNIT, other_output_type.clone()] + vec![ + branch_type.clone(), + nonlocal1_type.clone(), + nonlocal2_type.clone(), + Type::UNIT, + ], + vec![Type::UNIT, other_output_type.clone()], )) .unwrap(); let [b, nl1, nl2, unit] = outer.input_wires_arr(); let [unit, out] = { - let mut cfg = outer.cfg_builder([(Type::UNIT, unit), (branch_type.clone(), b)], vec![Type::UNIT, other_output_type.clone()].into()).unwrap(); + let mut cfg = outer + .cfg_builder( + [(Type::UNIT, unit), (branch_type.clone(), b)], + vec![Type::UNIT, other_output_type.clone()].into(), + ) + .unwrap(); let entry = { - let entry = cfg.entry_builder(branch_variants.clone(), other_output_type.clone().into()).unwrap(); + let entry = cfg + .entry_builder(branch_variants.clone(), other_output_type.clone().into()) + .unwrap(); let [_, b] = entry.input_wires_arr(); entry.finish_with_outputs(b, [b]).unwrap() @@ -880,25 +889,43 @@ mod test { let bb_left = { let mut entry = cfg - .block_builder(vec![Type::UNIT, other_output_type.clone()].into(), [type_row![]], other_output_type.clone().into()) + .block_builder( + vec![Type::UNIT, other_output_type.clone()].into(), + [type_row![]], + other_output_type.clone().into(), + ) .unwrap(); let [unit, oo] = entry.input_wires_arr(); - let [_] = entry.add_dataflow_op(Noop::new(nonlocal1_type), [nl1]).unwrap().outputs_arr(); - let [_] = entry.add_dataflow_op(Noop::new(nonlocal2_type), [nl2]).unwrap().outputs_arr(); + let [_] = entry + .add_dataflow_op(Noop::new(nonlocal1_type), [nl1]) + .unwrap() + .outputs_arr(); + let [_] = entry + .add_dataflow_op(Noop::new(nonlocal2_type), [nl2]) + .unwrap() + .outputs_arr(); entry.finish_with_outputs(unit, [oo]).unwrap() }; let bb_right = { let entry = cfg - .block_builder(vec![Type::UNIT, other_output_type.clone()].into(), [type_row![]], other_output_type.clone().into()) + .block_builder( + vec![Type::UNIT, other_output_type.clone()].into(), + [type_row![]], + other_output_type.clone().into(), + ) .unwrap(); - let [b, oo] = entry.input_wires_arr(); + let [_b, oo] = entry.input_wires_arr(); entry.finish_with_outputs(unit, [oo]).unwrap() }; let bb_bottom = { let entry = cfg - .block_builder(branch_type.clone().into(), branch_variants, other_output_type.clone().into()) + .block_builder( + branch_type.clone().into(), + branch_variants, + other_output_type.clone().into(), + ) .unwrap(); let [oo] = entry.input_wires_arr(); entry.finish_with_outputs(oo, [oo]).unwrap() From de0c513b46af102f93a064ae98ad0a2ff1df752a Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sun, 9 Feb 2025 08:54:54 +0000 Subject: [PATCH 13/61] feat: Add `HugrMutInternals::insert_ports` --- hugr-core/src/hugr/internal.rs | 97 +++++++++++++++++++++++++++++++++- 1 file changed, 96 insertions(+), 1 deletion(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 3f1c6b6ff7..1f67ff873a 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -6,7 +6,8 @@ use std::rc::Rc; use std::sync::Arc; use delegate::delegate; -use portgraph::{LinkView, MultiPortGraph, PortMut, PortView}; +use itertools::Itertools; +use portgraph::{LinkMut, LinkView, MultiPortGraph, PortMut, PortOffset, PortView}; use crate::ops::handle::NodeHandle; use crate::ops::OpTrait; @@ -174,6 +175,26 @@ pub trait HugrMutInternals: RootTagged { self.hugr_mut().add_ports(node, direction, amount) } + /// Insert `amount` new ports for a node, starting at `index`. The + /// `direction` parameter specifies whether to add ports to the incoming or + /// outgoing list. Links from this node are preserved, even when ports are + /// renumbered by the insertion. + /// + /// Returns the range of newly created ports. + /// # Panics + /// + /// If the node is not in the graph. + fn insert_ports( + &mut self, + node: Node, + direction: Direction, + index: usize, + amount: usize, + ) -> Range { + panic_invalid_node(self, node); + self.hugr_mut().insert_ports(node, direction, index, amount) + } + /// Sets the parent of a node. /// /// The node becomes the parent's last child. @@ -260,6 +281,46 @@ impl + AsMut> HugrMutInternals for T { .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}) } + fn insert_ports( + &mut self, + node: Node, + direction: Direction, + index: usize, + amount: usize, + ) -> Range { + let old_num_ports = match direction { + Direction::Incoming => self.base_hugr().graph.num_inputs(node.pg_index()), + Direction::Outgoing => self.base_hugr().graph.num_outputs(node.pg_index()), + }; + + let new_ports = self.add_ports(node, direction, amount as isize); + + for swap_from_port in (index..old_num_ports).rev() { + let swap_to_port = swap_from_port + amount; + let [from_port_index, to_port_index] = [swap_from_port, swap_to_port].map(|p| { + self.base_hugr() + .graph + .port_index(node.pg_index(), PortOffset::new(direction, p)) + .unwrap() + }); + let linked_ports = self + .base_hugr() + .graph + .port_links(from_port_index) + .map(|(_, to_subport)| to_subport.port()) + .collect_vec(); + self.hugr_mut().graph.unlink_port(from_port_index); + for linked_port_index in linked_ports { + let _ = self + .hugr_mut() + .graph + .link_ports(to_port_index, linked_port_index) + .expect("Ports exist"); + } + } + index..new_ports.len() + } + fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { let mut incoming = self.hugr_mut().graph.num_inputs(node.pg_index()); let mut outgoing = self.hugr_mut().graph.num_outputs(node.pg_index()); @@ -309,3 +370,37 @@ impl + AsMut> HugrMutInternals for T { Ok(std::mem::replace(cur, op.into())) } } + +#[cfg(test)] +mod test { + use crate::{ + builder::{DFGBuilder, Dataflow, DataflowHugr}, + extension::prelude::Noop, + hugr::internal::HugrMutInternals as _, + ops::handle::NodeHandle, + types::{Signature, Type}, + Direction, HugrView as _, + }; + + #[test] + fn insert_ports() { + let (nop, mut hugr) = { + let mut builder = DFGBuilder::new(Signature::new_endo(Type::UNIT)).unwrap(); + let [nop_in] = builder.input_wires_arr(); + let nop = builder + .add_dataflow_op(Noop::new(Type::UNIT), [nop_in]) + .unwrap(); + let [nop_out] = nop.outputs_arr(); + ( + nop.node(), + builder.finish_hugr_with_outputs([nop_out]).unwrap(), + ) + }; + let [i, o] = hugr.get_io(hugr.root()).unwrap(); + hugr.insert_ports(nop, Direction::Incoming, 0, 2); + hugr.insert_ports(nop, Direction::Outgoing, 0, 2); + + assert_eq!(hugr.single_linked_input(i, 0), Some((nop, 2.into()))); + assert_eq!(hugr.single_linked_output(o, 0), Some((nop, 2.into()))); + } +} From f096dd727bbadb3f9c6e01e3317434437bd526dc Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 11:25:03 +0000 Subject: [PATCH 14/61] fixes --- hugr-core/src/hugr/internal.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 1f67ff873a..51cb7af584 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -293,7 +293,7 @@ impl + AsMut> HugrMutInternals for T { Direction::Outgoing => self.base_hugr().graph.num_outputs(node.pg_index()), }; - let new_ports = self.add_ports(node, direction, amount as isize); + self.add_ports(node, direction, amount as isize); for swap_from_port in (index..old_num_ports).rev() { let swap_to_port = swap_from_port + amount; @@ -318,7 +318,7 @@ impl + AsMut> HugrMutInternals for T { .expect("Ports exist"); } } - index..new_ports.len() + index..index + amount } fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { @@ -374,7 +374,7 @@ impl + AsMut> HugrMutInternals for T { #[cfg(test)] mod test { use crate::{ - builder::{DFGBuilder, Dataflow, DataflowHugr}, + builder::{Container, DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::Noop, hugr::internal::HugrMutInternals as _, ops::handle::NodeHandle, @@ -390,6 +390,7 @@ mod test { let nop = builder .add_dataflow_op(Noop::new(Type::UNIT), [nop_in]) .unwrap(); + builder.add_other_wire(nop.node(), builder.output().node()); let [nop_out] = nop.outputs_arr(); ( nop.node(), @@ -397,10 +398,11 @@ mod test { ) }; let [i, o] = hugr.get_io(hugr.root()).unwrap(); - hugr.insert_ports(nop, Direction::Incoming, 0, 2); - hugr.insert_ports(nop, Direction::Outgoing, 0, 2); + assert_eq!(0..2, hugr.insert_ports(nop, Direction::Incoming, 0, 2)); + assert_eq!(1..3, hugr.insert_ports(nop, Direction::Outgoing, 1, 2)); assert_eq!(hugr.single_linked_input(i, 0), Some((nop, 2.into()))); - assert_eq!(hugr.single_linked_output(o, 0), Some((nop, 2.into()))); + assert_eq!(hugr.single_linked_output(o, 0), Some((nop, 0.into()))); + assert_eq!(hugr.single_linked_output(o, 1), Some((nop, 3.into()))); } } From 504afe507eee5203431993156f926fccd0046aa8 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 11:33:42 +0000 Subject: [PATCH 15/61] with_prelude --- hugr-core/src/hugr/internal.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 51cb7af584..33d791266d 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -385,7 +385,8 @@ mod test { #[test] fn insert_ports() { let (nop, mut hugr) = { - let mut builder = DFGBuilder::new(Signature::new_endo(Type::UNIT)).unwrap(); + let mut builder = + DFGBuilder::new(Signature::new_endo(Type::UNIT).with_prelude()).unwrap(); let [nop_in] = builder.input_wires_arr(); let nop = builder .add_dataflow_op(Noop::new(Type::UNIT), [nop_in]) From 6cc87a465edb5bba832b48300b2d514c790fa2c0 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:26:20 +0000 Subject: [PATCH 16/61] works --- hugr-core/src/hugr/hugrmut.rs | 45 +--- hugr-passes/src/non_local.rs | 412 ++++++++++++++++++---------------- 2 files changed, 214 insertions(+), 243 deletions(-) diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index b76d1897fe..4056f36e61 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -4,7 +4,6 @@ use core::panic; use std::collections::HashMap; use std::sync::Arc; -use itertools::Itertools as _; use portgraph::view::{NodeFilter, NodeFiltered}; use portgraph::{LinkMut, NodeIndex, PortMut, PortView, SecondaryMap}; @@ -12,7 +11,7 @@ use crate::extension::ExtensionRegistry; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrView, Node, OpType, RootTagged}; use crate::hugr::{NodeMetadata, Rewrite}; -use crate::{Direction, Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; +use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; use super::internal::HugrMutInternals; use super::NodeMetadataMap; @@ -279,48 +278,6 @@ pub trait HugrMut: HugrMutInternals { fn extensions_mut(&mut self) -> &mut ExtensionRegistry { &mut self.hugr_mut().extensions } - - /// TODO perhaps these should be on HugrMut? - fn insert_incoming_port(&mut self, node: Node, index: usize) -> IncomingPort { - let _ = self - .add_ports(node, Direction::Incoming, 1) - .exactly_one() - .unwrap(); - - for (to, from) in (index..self.num_inputs(node)) - .map_into::() - .rev() - .tuple_windows() - { - let linked_outputs = self.linked_outputs(node, from).collect_vec(); - self.disconnect(node, from); - for (linked_node, linked_port) in linked_outputs { - self.connect(linked_node, linked_port, node, to); - } - } - index.into() - } - - /// TODO perhaps these should be on HugrMut? - fn insert_outgoing_port(&mut self, node: Node, index: usize) -> OutgoingPort { - let _ = self - .add_ports(node, Direction::Outgoing, 1) - .exactly_one() - .unwrap(); - - for (to, from) in (index..self.num_outputs(node)) - .map_into::() - .rev() - .tuple_windows() - { - let linked_inputs = self.linked_inputs(node, from).collect_vec(); - self.disconnect(node, from); - for (linked_node, linked_port) in linked_inputs { - self.connect(node, to, linked_node, linked_port); - } - } - index.into() - } } /// Records the result of inserting a Hugr or view diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 8600f19c50..20344bc6a6 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -7,7 +7,7 @@ use std::{ }; //TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions -use itertools::{Either, Itertools as _}; +use itertools::Itertools as _; use hugr_core::{ builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, @@ -18,7 +18,7 @@ use hugr_core::{ }, ops::{DataflowOpTrait as _, OpType, Tag, TailLoop}, types::{EdgeKind, Type, TypeRow}, - HugrView, IncomingPort, Node, PortIndex, Wire, + Direction, HugrView, IncomingPort, Node, PortIndex, Wire, }; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -95,24 +95,34 @@ struct WorkItem { } #[derive(Clone, Default, Debug)] -struct ParentSourceMap(HashMap>); +struct ParentSourceMap(HashMap>); impl ParentSourceMap { - // fn contains_parent(&self, parent: Node) -> bool { - // self.0.contains_key(&parent) - // } - fn insert_sources_in_parent( &mut self, parent: Node, - sources: impl IntoIterator, + sources: impl IntoIterator, ) { debug_assert!(!self.0.contains_key(&parent)); - self.0.entry(parent).or_default().extend(sources); + self.0 + .entry(parent) + .or_default() + .extend(sources.into_iter().map(|(s, p, t)| (s, (p, t)))); } - fn get_source_in_parent(&self, parent: Node, source: Wire) -> Option { - self.0.get(&parent).and_then(|m| m.get(&source).cloned()) + fn get_source_in_parent( + &self, + parent: Node, + source: Wire, + ref hugr: impl HugrView, + ) -> (Wire, Type) { + let r @ (w, _) = self + .0 + .get(&parent) + .and_then(|m| m.get(&source).cloned()) + .unwrap(); + debug_assert_eq!(hugr.get_parent(w.node()).unwrap(), parent); + r } fn thread_dataflow_parent( @@ -121,26 +131,122 @@ impl ParentSourceMap { parent: Node, start_port_index: usize, sources: impl IntoIterator, - ) -> impl Iterator { - let [input_n, _] = hugr.get_io(parent).unwrap(); - let OpType::Input(mut input) = hugr.get_optype(input_n).clone() else { + ) { + let (source_wires, source_types): (Vec<_>, Vec<_>) = sources.into_iter().unzip(); + let input_wires = { + let [input_n, _] = hugr.get_io(parent).unwrap(); + let Some(mut input) = hugr.get_optype(input_n).as_input().cloned() else { + panic!("impossible") + }; + vec_insert(input.types.to_mut(), source_types.clone(), start_port_index); + hugr.replace_op(input_n, input).unwrap(); + hugr.insert_ports( + input_n, + Direction::Outgoing, + start_port_index, + source_wires.len(), + ) + .map(move |new_port| Wire::new(input_n, new_port)) + .collect_vec() + }; + self.insert_sources_in_parent( + parent, + itertools::izip!(source_wires, input_wires, source_types), + ); + } +} + +#[derive(Clone, Debug)] +struct ControlWorkItem { + output_node: Node, + variant_source_prefixes: Vec>, +} + +impl ControlWorkItem { + fn go(self, hugr: &mut impl HugrMut, psm: &ParentSourceMap) { + let parent = hugr.get_parent(self.output_node).unwrap(); + let Some(mut output) = hugr.get_optype(self.output_node).as_output().cloned() else { panic!("impossible") }; - let mut input_wires = vec![]; - self.0 - .entry(parent) - .or_default() - .extend(sources.into_iter().enumerate().map(|(i, (source, ty))| { - input.types.to_mut().insert(start_port_index + i, ty); - let input_wire = Wire::new( - input_n, - hugr.insert_outgoing_port(input_n, start_port_index + i), - ); - input_wires.push(input_wire); - (source, input_wire) - })); - hugr.replace_op(input_n, input).unwrap(); - input_wires.into_iter() + let mut needed_sources = BTreeMap::new(); + let (cond, new_control_type) = { + let Some(EdgeKind::Value(control_type)) = hugr + .get_optype(self.output_node) + .port_kind(IncomingPort::from(0)) + else { + panic!("impossible") + }; + let Some(sum_type) = control_type.as_sum() else { + panic!("impossible") + }; + + let mut type_for_source = |source: &Wire| { + let (w, t) = psm.get_source_in_parent(parent, *source, &hugr); + let replaced = needed_sources.insert(*source, (w, t.clone())); + debug_assert!(!replaced.is_some_and(|x| x != (w, t.clone()))); + t + }; + let old_sum_rows: Vec = sum_type + .variants() + .map(|x| x.clone().try_into().unwrap()) + .collect_vec(); + let new_sum_rows: Vec = + itertools::zip_eq(self.variant_source_prefixes.iter(), old_sum_rows.iter()) + .map(|(new_sources, old_tys)| { + new_sources + .iter() + .map(&mut type_for_source) + .chain(old_tys.iter().cloned()) + .collect_vec() + .into() + }) + .collect_vec(); + + let new_control_type = Type::new_sum(new_sum_rows.clone()); + let mut cond = ConditionalBuilder::new( + old_sum_rows.clone(), + needed_sources + .values() + .map(|(_, t)| t.clone()) + .collect_vec(), + new_control_type.clone(), + ) + .unwrap(); + for (i, new_sources) in self.variant_source_prefixes.into_iter().enumerate() { + let mut case = cond.case_builder(i).unwrap(); + let case_inputs = case.input_wires().collect_vec(); + let mut args = new_sources + .into_iter() + .map(|s| { + case_inputs[old_sum_rows[i].len() + + needed_sources + .iter() + .find_position(|(&w, _)| w == s) + .unwrap() + .0] + }) + .collect_vec(); + args.extend(&case_inputs[..old_sum_rows[i].len()]); + let case_outputs = case + .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), args) + .unwrap() + .outputs(); + case.finish_with_outputs(case_outputs).unwrap(); + } + (cond.finish_hugr().unwrap(), new_control_type) + }; + let cond_node = hugr.insert_hugr(parent, cond).new_root; + let (old_output_source_node, old_output_source_port) = + hugr.single_linked_output(self.output_node, 0).unwrap(); + debug_assert_eq!(hugr.get_parent(old_output_source_node).unwrap(), parent); + hugr.connect(old_output_source_node, old_output_source_port, cond_node, 0); + for (i, &(w, _)) in needed_sources.values().enumerate() { + hugr.connect(w.node(), w.source(), cond_node, i + 1); + } + hugr.disconnect(self.output_node, IncomingPort::from(0)); + hugr.connect(cond_node, 0, self.output_node, 0); + output.types.to_mut()[0] = new_control_type; + hugr.replace_op(self.output_node, output).unwrap(); } } @@ -149,21 +255,19 @@ struct ThreadState<'a> { parent_source_map: ParentSourceMap, needs: &'a BBNeedsSourcesMap, worklist: Vec, + control_worklist: Vec, } impl<'a> ThreadState<'a> { delegate! { to self.parent_source_map { - // fn contains_parent(&self, parent: Node) -> bool; - // fn get_source_in_parent(&self, parent: Node, source: Wire) -> Option; - fn insert_sources_in_parent(&mut self, parent: Node, sources: impl IntoIterator); fn thread_dataflow_parent( &mut self, hugr: &mut impl HugrMut, parent: Node, start_port_index: usize, sources: impl IntoIterator, - ) -> impl Iterator; + ); } } @@ -172,6 +276,7 @@ impl<'a> ThreadState<'a> { parent_source_map: Default::default(), needs: bbnsm, worklist: vec![], + control_worklist: vec![], } } @@ -189,14 +294,20 @@ impl<'a> ThreadState<'a> { for (this_p, succ_n) in hugr.node_outputs(node).filter_map(|out_p| { let (succ_n, _) = hugr.single_linked_input(node, out_p).unwrap(); - hugr.get_optype(succ_n).is_dataflow_block().then_some((out_p.index(), succ_n)) + hugr.get_optype(succ_n) + .is_dataflow_block() + .then_some((out_p.index(), succ_n)) }) { let succ_needs = &self.needs[&succ_n]; let succ_needs_source_indices = succ_needs .iter() .map(|(&w, _)| sources.iter().find_position(|(x, _)| x == &w).unwrap().0) .collect_vec(); - let succ_needs_tys = succ_needs_source_indices.iter().copied().map(|x| sources[x].1.clone()).collect_vec(); + let succ_needs_tys = succ_needs_source_indices + .iter() + .copied() + .map(|x| sources[x].1.clone()) + .collect_vec(); vec_prepend(this_dfb.sum_rows[this_p].to_mut(), succ_needs_tys); nsrp[this_p] = succ_needs_source_indices; } @@ -204,88 +315,28 @@ impl<'a> ThreadState<'a> { nsrp }; - let input_wires = self - .thread_dataflow_parent(hugr, node, 0, sources.clone()) - .collect_vec(); + self.thread_dataflow_parent(hugr, node, 0, sources.clone()); let [_, o] = hugr.get_io(node).unwrap(); - let (cond, new_control_type) = { - let Some(EdgeKind::Value(control_type)) = - hugr.get_optype(o).port_kind(IncomingPort::from(0)) - else { - panic!("impossible") - }; - let Some(sum_type) = control_type.as_sum() else { - panic!("impossible") - }; - - let old_sum_rows: Vec = sum_type - .variants() - .map(|x| x.clone().try_into().unwrap()) - .collect_vec(); - let new_sum_rows: Vec = - itertools::zip_eq(new_sum_row_prefixes.iter(), old_sum_rows.iter()) - .map(|(new_source_indices, old_tys)| { - new_source_indices.into_iter() - .map(|&x| sources[x].1.clone()) - .chain(old_tys.iter().cloned()) - .collect_vec() - .into() - }) - .collect_vec(); - - let new_control_type = Type::new_sum(new_sum_rows.clone()); - let mut cond = ConditionalBuilder::new( - old_sum_rows.clone(), - types.clone(), - new_control_type.clone(), - ) - .unwrap(); - for (i, new_source_indices) in new_sum_row_prefixes.into_iter().enumerate() { - let mut case = cond.case_builder(i).unwrap(); - let case_inputs = case.input_wires().collect_vec(); - let mut args = vec![]; - for source_i in new_source_indices { - args.push(case_inputs[old_sum_rows[i].len() + source_i]); - } - - args.extend(&case_inputs[..old_sum_rows[i].len()]); - - let case_outputs = case - .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), args) - .unwrap() - .outputs(); - case.finish_with_outputs(case_outputs).unwrap(); - } - (cond.finish_hugr().unwrap(), new_control_type) - }; - let cond_node = hugr.insert_hugr(node, cond).new_root; - let (n, p) = hugr.single_linked_output(o, 0).unwrap(); - hugr.connect(n, p, cond_node, 0); - for (i, w) in input_wires.into_iter().enumerate() { - hugr.connect(w.node(), w.source(), cond_node, i + 1); - } - hugr.disconnect(o, IncomingPort::from(0)); - hugr.connect(cond_node, 0, o, 0); - let mut output = hugr.get_optype(o).as_output().unwrap().clone(); - output.types.to_mut()[0] = new_control_type; - hugr.replace_op(o, output).unwrap(); - dbg!(hugr.single_linked_output(o, 0)); + self.control_worklist.push(ControlWorkItem { + output_node: o, + variant_source_prefixes: new_sum_row_prefixes + .into_iter() + .map(|v| v.into_iter().map(|i| sources[i].0.clone()).collect_vec()) + .collect_vec(), + }); } fn do_cfg(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { let types = sources.iter().map(|x| x.1.clone()).collect_vec(); { let mut cfg = hugr.get_optype(node).as_cfg().unwrap().clone(); - let inputs = cfg.signature.input.to_mut(); - let old_inputs = mem::replace(inputs, types); - inputs.extend(old_inputs); + vec_insert(cfg.signature.input.to_mut(), types, 0); hugr.replace_op(node, cfg).unwrap(); } - let new_cond_ports = (0..sources.len()) - .map(|i| hugr.insert_incoming_port(node, i)) - .collect_vec(); - self.insert_sources_in_parent(node, iter::empty()); + let new_cond_ports = hugr + .insert_ports(node, Direction::Incoming, 0, sources.len()) + .map_into(); self.worklist .extend(mk_workitems(node, sources, new_cond_ports)) } @@ -293,16 +344,20 @@ impl<'a> ThreadState<'a> { fn do_dfg(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { let mut dfg = hugr.get_optype(node).as_dfg().unwrap().clone(); let start_new_port_index = dfg.signature.input().len(); - let new_dfg_ports = (0..sources.len()) - .map(|i| hugr.insert_incoming_port(node, start_new_port_index + i)) - .collect_vec(); + let new_dfg_ports = hugr + .insert_ports( + node, + Direction::Incoming, + start_new_port_index, + sources.len(), + ) + .map_into(); dfg.signature .input .to_mut() .extend(sources.iter().map(|x| x.1.clone())); hugr.replace_op(node, dfg).unwrap(); - let _ = - self.thread_dataflow_parent(hugr, node, start_new_port_index, sources.iter().cloned()); + self.thread_dataflow_parent(hugr, node, start_new_port_index, sources.iter().cloned()); self.worklist .extend(mk_workitems(node, sources, new_dfg_ports)); } @@ -314,10 +369,14 @@ impl<'a> ThreadState<'a> { .to_mut() .extend(sources.iter().map(|x| x.1.clone())); hugr.replace_op(node, cond).unwrap(); - let new_cond_ports = (0..sources.len()) - .map(|i| hugr.insert_incoming_port(node, start_new_port_index + i)) - .collect_vec(); - self.insert_sources_in_parent(node, iter::empty()); + let new_cond_ports = hugr + .insert_ports( + node, + Direction::Incoming, + start_new_port_index, + sources.len(), + ) + .map_into(); self.worklist .extend(mk_workitems(node, sources, new_cond_ports)) } @@ -330,95 +389,48 @@ impl<'a> ThreadState<'a> { .to_mut() .extend(sources.iter().map(|x| x.1.clone())); hugr.replace_op(node, case).unwrap(); - let _ = self.thread_dataflow_parent(hugr, node, start_case_port_index, sources); + self.thread_dataflow_parent(hugr, node, start_case_port_index, sources); } fn do_tailloop(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { let mut tailloop = hugr.get_optype(node).as_tail_loop().unwrap().clone(); let types = sources.iter().map(|x| x.1.clone()).collect_vec(); - let start_port_index = tailloop.just_inputs.len(); { - tailloop.just_inputs.to_mut().extend(types.clone()); + vec_prepend(tailloop.just_inputs.to_mut(), types.clone()); hugr.replace_op(node, tailloop).unwrap(); } - let tailloop_ports = (0..sources.len()) - .map(|i| hugr.insert_incoming_port(node, start_port_index + i)) - .collect_vec(); + let tailloop_ports = hugr + .insert_ports(node, Direction::Incoming, 0, sources.len()) + .map_into(); - let input_wires = self - .thread_dataflow_parent(hugr, node, start_port_index, sources.clone()) - .collect_vec(); + self.thread_dataflow_parent(hugr, node, 0, sources.clone()); let [_, o] = hugr.get_io(node).unwrap(); - let (cond, new_control_type) = { - let Some(EdgeKind::Value(control_type)) = - hugr.get_optype(o).port_kind(IncomingPort::from(0)) - else { - panic!("impossible") - }; - let Some(sum_type) = control_type.as_sum() else { - panic!("impossible") - }; - - let old_sum_rows: Vec = sum_type - .variants() - .map(|x| x.clone().try_into().unwrap()) - .collect_vec(); - let new_sum_rows = { - let mut v = old_sum_rows.clone(); - v[TailLoop::CONTINUE_TAG] - .to_mut() - .extend(types.iter().cloned()); - v - }; - - let new_control_type = Type::new_sum(new_sum_rows.clone()); - let mut cond = - ConditionalBuilder::new(old_sum_rows, types.clone(), new_control_type.clone()) - .unwrap(); - for i in 0..2 { - let mut case = cond.case_builder(i).unwrap(); - let inputs = { - let all_inputs = case.input_wires(); - if i == TailLoop::CONTINUE_TAG { - Either::Left(all_inputs) - } else { - Either::Right(all_inputs.into_iter().dropping_back(types.len())) - } - }; - - let case_outputs = case - .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), inputs) - .unwrap() - .outputs(); - case.finish_with_outputs(case_outputs).unwrap(); - } - (cond.finish_hugr().unwrap(), new_control_type) + let new_sum_row_prefixes = { + let mut v = vec![vec![]; 2]; + v[TailLoop::CONTINUE_TAG].extend(sources.iter().map(|x| x.0)); + v }; - let cond_node = hugr.insert_hugr(node, cond).new_root; - let (n, p) = hugr.single_linked_output(o, 0).unwrap(); - hugr.connect(n, p, cond_node, 0); - for (i, w) in input_wires.into_iter().enumerate() { - hugr.connect(w.node(), w.source(), cond_node, i + 1); - } - hugr.disconnect(o, IncomingPort::from(0)); - hugr.connect(cond_node, 0, o, 0); - let mut output = hugr.get_optype(o).as_output().unwrap().clone(); - output.types.to_mut()[0] = new_control_type; - hugr.replace_op(o, output).unwrap(); + self.control_worklist.push(ControlWorkItem { + output_node: o, + variant_source_prefixes: new_sum_row_prefixes, + }); self.worklist .extend(mk_workitems(node, sources, tailloop_ports)) } - fn finish(self, _hugr: &mut impl HugrMut) -> (Vec, ParentSourceMap) { - (self.worklist, self.parent_source_map) + fn finish( + self, + _hugr: &mut impl HugrMut, + ) -> (Vec, ParentSourceMap, Vec) { + (self.worklist, self.parent_source_map, self.control_worklist) } } fn thread_sources( hugr: &mut impl HugrMut, bb_needs_sources_map: &BBNeedsSourcesMap, -) -> (Vec, ParentSourceMap) { +) -> (Vec, ParentSourceMap, Vec) { let mut state = ThreadState::new(bb_needs_sources_map); for (&bb, sources) in bb_needs_sources_map { let sources = sources.iter().map(|(&w, ty)| (w, ty.clone())).collect_vec(); @@ -516,11 +528,6 @@ impl BBNeedsSourcesMapBuilder { .children(cfg) .filter(|&child| hugr.get_optype(child).is_dataflow_block()) .collect_vec(); - - // let mut dfb_needs_map: HashMap<_, _> = dfbs - // .iter() - // .map(|&n| (n, self.0.get(&n).cloned().unwrap_or_default())) - // .collect(); loop { let mut any_change = false; for &dfb in dfbs.iter() { @@ -627,11 +634,11 @@ pub fn remove_nonlocal_edges( } } - let (parent_source_map, worklist) = { + let (parent_source_map, worklist, control_worklist) = { let mut worklist = nonlocal_edges_map.into_values().collect_vec(); - let (wl, psm) = thread_sources(hugr, &bb_needs_sources_map); + let (wl, psm, control_worklist) = thread_sources(hugr, &bb_needs_sources_map); worklist.extend(wl); - (psm, worklist) + (psm, worklist, control_worklist) }; for wi in worklist { @@ -640,28 +647,35 @@ pub fn remove_nonlocal_edges( wi.source } else { parent_source_map - .get_source_in_parent(parent, wi.source) - .unwrap() + .get_source_in_parent(parent, wi.source, &hugr) + .0 }; debug_assert_eq!(hugr.get_parent(source.node()), hugr.get_parent(wi.target.0)); hugr.disconnect(wi.target.0, wi.target.1); hugr.connect(source.node(), source.source(), wi.target.0, wi.target.1); } + for cwi in control_worklist { + cwi.go(hugr, &parent_source_map) + } + Ok(()) } fn vec_prepend(v: &mut Vec, ts: impl IntoIterator) { - let mut old_v = mem::replace(v, ts.into_iter().collect()); - v.extend(old_v.drain(..)); + vec_insert(v, ts, 0) +} + +fn vec_insert(v: &mut Vec, ts: impl IntoIterator, index: usize) { + let mut old_v_iter = mem::replace(v, vec![]).into_iter(); + v.extend(old_v_iter.by_ref().take(index).chain(ts)); + v.extend(old_v_iter); } #[cfg(test)] mod test { use hugr_core::{ - builder::{ - DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer, - }, + builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, extension::prelude::{bool_t, either_type, Noop}, ops::{handle::NodeHandle, Tag, TailLoop, Value}, type_row, @@ -722,7 +736,7 @@ mod test { let [w0] = outer.input_wires_arr(); let [w1] = { let inner = outer - .dfg_builder(Signature::new(type_row![], bool_t()), []) + .dfg_builder(Signature::new_endo(bool_t()), [w0]) .unwrap(); inner.finish_with_outputs([w0]).unwrap().outputs_arr() }; From 461a5ab6178c8ffccccd484935b52a90133cd3e8 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:27:53 +0000 Subject: [PATCH 17/61] wip --- hugr-passes/src/non_local.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 20344bc6a6..0293323a9e 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -322,7 +322,7 @@ impl<'a> ThreadState<'a> { output_node: o, variant_source_prefixes: new_sum_row_prefixes .into_iter() - .map(|v| v.into_iter().map(|i| sources[i].0.clone()).collect_vec()) + .map(|v| v.into_iter().map(|i| sources[i].0).collect_vec()) .collect_vec(), }); } @@ -667,7 +667,7 @@ fn vec_prepend(v: &mut Vec, ts: impl IntoIterator) { } fn vec_insert(v: &mut Vec, ts: impl IntoIterator, index: usize) { - let mut old_v_iter = mem::replace(v, vec![]).into_iter(); + let mut old_v_iter = mem::take(v).into_iter(); v.extend(old_v_iter.by_ref().take(index).chain(ts)); v.extend(old_v_iter); } From 3a07aa3f1d6c8e52bfe2ac03d3741f815220fb38 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:31:57 +0000 Subject: [PATCH 18/61] get get_optype_mut --- hugr-core/src/hugr/internal.rs | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index b797335c9a..54d9004cdc 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -266,13 +266,6 @@ pub trait HugrMutInternals: RootTagged { } self.hugr_mut().replace_op(node, op) } - - /// TODO docs - fn get_optype_mut(&mut self, node: Node) -> Result<&mut OpType, HugrError> { - panic_invalid_node(self, node); - // TODO refuse if node == self.root() because tag might be violated - self.hugr_mut().get_optype_mut(node) - } } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. @@ -371,14 +364,10 @@ impl + AsMut> HugrMutInternals for T { fn replace_op(&mut self, node: Node, op: impl Into) -> Result { // We know RootHandle=Node here so no need to check Ok(std::mem::replace( - self.hugr_mut().get_optype_mut(node)?, + self.hugr_mut().op_types.get_mut(node.pg_index()), op.into(), )) } - - fn get_optype_mut(&mut self, node: Node) -> Result<&mut OpType, HugrError> { - Ok(self.hugr_mut().op_types.get_mut(node.pg_index())) - } } #[cfg(test)] From 9d251df207c93cdfb67e97ee062ad782d5ac8dc1 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:32:25 +0000 Subject: [PATCH 19/61] fix merge --- hugr-core/src/hugr/internal.rs | 37 ---------------------------------- 1 file changed, 37 deletions(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 54d9004cdc..3e98ac0f20 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -406,40 +406,3 @@ mod test { assert_eq!(hugr.single_linked_output(o, 1), Some((nop, 3.into()))); } } - -#[cfg(test)] -mod test { - use crate::{ - builder::{Container, DFGBuilder, Dataflow, DataflowHugr}, - extension::prelude::Noop, - hugr::internal::HugrMutInternals as _, - ops::handle::NodeHandle, - types::{Signature, Type}, - Direction, HugrView as _, - }; - - #[test] - fn insert_ports() { - let (nop, mut hugr) = { - let mut builder = - DFGBuilder::new(Signature::new_endo(Type::UNIT).with_prelude()).unwrap(); - let [nop_in] = builder.input_wires_arr(); - let nop = builder - .add_dataflow_op(Noop::new(Type::UNIT), [nop_in]) - .unwrap(); - builder.add_other_wire(nop.node(), builder.output().node()); - let [nop_out] = nop.outputs_arr(); - ( - nop.node(), - builder.finish_hugr_with_outputs([nop_out]).unwrap(), - ) - }; - let [i, o] = hugr.get_io(hugr.root()).unwrap(); - assert_eq!(0..2, hugr.insert_ports(nop, Direction::Incoming, 0, 2)); - assert_eq!(1..3, hugr.insert_ports(nop, Direction::Outgoing, 1, 2)); - - assert_eq!(hugr.single_linked_input(i, 0), Some((nop, 2.into()))); - assert_eq!(hugr.single_linked_output(o, 0), Some((nop, 0.into()))); - assert_eq!(hugr.single_linked_output(o, 1), Some((nop, 3.into()))); - } -} From 9a8a5e257e34abb9bc71c10f17d04ec42347bf9c Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:33:51 +0000 Subject: [PATCH 20/61] tweak --- hugr-core/src/hugr/internal.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 3e98ac0f20..75b0aab1d2 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -363,10 +363,8 @@ impl + AsMut> HugrMutInternals for T { fn replace_op(&mut self, node: Node, op: impl Into) -> Result { // We know RootHandle=Node here so no need to check - Ok(std::mem::replace( - self.hugr_mut().op_types.get_mut(node.pg_index()), - op.into(), - )) + let cur = self.hugr_mut().op_types.get_mut(node.pg_index()); + Ok(std::mem::replace(cur, op.into())) } } From dd4caa0f0fa8ad7c85f5d8556fd8fe5408c8d66f Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:52:29 +0000 Subject: [PATCH 21/61] fmt --- hugr-passes/src/non_local.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 0293323a9e..cba63b5496 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -114,7 +114,7 @@ impl ParentSourceMap { &self, parent: Node, source: Wire, - ref hugr: impl HugrView, + hugr: impl HugrView, ) -> (Wire, Type) { let r @ (w, _) = self .0 @@ -470,7 +470,8 @@ impl BBNeedsSourcesMapBuilder { self.0.entry(bb).or_default().insert(source, ty); } - fn extend_parent_needs_for(&mut self, ref hugr: impl HugrView, child: Node) -> bool { + fn extend_parent_needs_for(&mut self, hugr: impl HugrView, child: Node) -> bool { + let hugr = &hugr; let parent = hugr.get_parent(child).unwrap(); let parent_needs = self .0 From 6dabc6bd5f1b3bd52b26ffe834d39638bb670da9 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:53:45 +0000 Subject: [PATCH 22/61] with_prelude --- hugr-passes/src/non_local.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index cba63b5496..4a654ec217 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -881,7 +881,7 @@ mod test { Type::UNIT, ], vec![Type::UNIT, other_output_type.clone()], - )) + ).with_prelude()) .unwrap(); let [b, nl1, nl2, unit] = outer.input_wires_arr(); let [unit, out] = { From 322facfc4cb29d64802e8d59fcc459c9b47c3b81 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Wed, 12 Feb 2025 10:19:51 +0000 Subject: [PATCH 23/61] wip --- devenv.lock | 49 +++---- hugr-passes/src/non_local.rs | 253 ++++++++++++++++++----------------- 2 files changed, 159 insertions(+), 143 deletions(-) diff --git a/devenv.lock b/devenv.lock index d606d21055..99f06fb280 100644 --- a/devenv.lock +++ b/devenv.lock @@ -51,10 +51,31 @@ "type": "github" } }, + "git-hooks": { + "inputs": { + "flake-compat": "flake-compat", + "gitignore": "gitignore", + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1737465171, + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "9364dc02281ce2d37a1f55b6e51f7c0f65a75f17", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, "gitignore": { "inputs": { "nixpkgs": [ - "pre-commit-hooks", + "git-hooks", "nixpkgs" ] }, @@ -101,34 +122,16 @@ "type": "github" } }, - "pre-commit-hooks": { - "inputs": { - "flake-compat": "flake-compat", - "gitignore": "gitignore", - "nixpkgs": [ - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1735882644, - "owner": "cachix", - "repo": "pre-commit-hooks.nix", - "rev": "a5a961387e75ae44cc20f0a57ae463da5e959656", - "type": "github" - }, - "original": { - "owner": "cachix", - "repo": "pre-commit-hooks.nix", - "type": "github" - } - }, "root": { "inputs": { "devenv": "devenv", "fenix": "fenix", + "git-hooks": "git-hooks", "nixpkgs": "nixpkgs", "nixpkgs-stable": "nixpkgs-stable", - "pre-commit-hooks": "pre-commit-hooks" + "pre-commit-hooks": [ + "git-hooks" + ] } }, "rust-analyzer-src": { diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 4a654ec217..ad62e8a99d 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -7,7 +7,7 @@ use std::{ }; //TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions -use itertools::Itertools as _; +use itertools::{Either, Itertools as _}; use hugr_core::{ builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, @@ -94,6 +94,22 @@ struct WorkItem { ty: Type, } +impl WorkItem { + pub fn go(self, hugr: &mut impl HugrMut, parent_source_map: &ParentSourceMap) { + let parent = hugr.get_parent(self.target.0).unwrap(); + let source = if hugr.get_parent(self.source.node()).unwrap() == parent { + self.source + } else { + parent_source_map + .get_source_in_parent(parent, self.source, &hugr) + .0 + }; + debug_assert_eq!(hugr.get_parent(source.node()), hugr.get_parent(self.target.0)); + hugr.disconnect(self.target.0, self.target.1); + hugr.connect(source.node(), source.source(), self.target.0, self.target.1); + } +} + #[derive(Clone, Default, Debug)] struct ParentSourceMap(HashMap>); @@ -298,10 +314,8 @@ impl<'a> ThreadState<'a> { .is_dataflow_block() .then_some((out_p.index(), succ_n)) }) { - let succ_needs = &self.needs[&succ_n]; - let succ_needs_source_indices = succ_needs - .iter() - .map(|(&w, _)| sources.iter().find_position(|(x, _)| x == &w).unwrap().0) + let succ_needs_source_indices = self.needs.get(succ_n) + .map(|(w, _)| sources.iter().find_position(|(x, _)| x == &w).unwrap().0) .collect_vec(); let succ_needs_tys = succ_needs_source_indices .iter() @@ -432,7 +446,7 @@ fn thread_sources( bb_needs_sources_map: &BBNeedsSourcesMap, ) -> (Vec, ParentSourceMap, Vec) { let mut state = ThreadState::new(bb_needs_sources_map); - for (&bb, sources) in bb_needs_sources_map { + for (bb, sources) in bb_needs_sources_map { let sources = sources.iter().map(|(&w, ty)| (w, ty.clone())).collect_vec(); match hugr.get_optype(bb).clone() { OpType::DFG(_) => state.do_dfg(hugr, bb, sources), @@ -460,82 +474,105 @@ fn mk_workitems( }) } -type BBNeedsSourcesMap = HashMap>; - #[derive(Debug, Default, Clone)] -struct BBNeedsSourcesMapBuilder(BBNeedsSourcesMap); +struct BBNeedsSourcesMap(BTreeMap>); -impl BBNeedsSourcesMapBuilder { - fn insert(&mut self, bb: Node, source: Wire, ty: Type) { - self.0.entry(bb).or_default().insert(source, ty); +struct NeedsSourcesMapIter<'a>(<&'a BTreeMap> as IntoIterator>::IntoIter); + +impl<'a> Iterator for NeedsSourcesMapIter<'a> { + type Item = (Node, &'a BTreeMap); + + fn next(&mut self) -> Option { + self.0.next().map(|(&n,bt)| (n,bt)) } +} - fn extend_parent_needs_for(&mut self, hugr: impl HugrView, child: Node) -> bool { - let hugr = &hugr; - let parent = hugr.get_parent(child).unwrap(); - let parent_needs = self - .0 - .get(&child) - .into_iter() - .flat_map(move |m| { - m.iter() - .filter(move |(w, _)| hugr.get_parent(w.node()).unwrap() != parent) - .map(|(&w, ty)| (w, ty.clone())) - }) - .collect_vec(); - let any = !parent_needs.is_empty(); - if any { - self.0.entry(parent).or_default().extend(parent_needs); +impl<'a> IntoIterator for &'a BBNeedsSourcesMap { + type Item = as Iterator>::Item; + + type IntoIter = NeedsSourcesMapIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + NeedsSourcesMapIter((&self.0).into_iter()) + } +} + +impl BBNeedsSourcesMap { + fn insert(&mut self, node: Node, source: Wire, ty: Type) -> bool { + self.0.entry(node).or_default().insert(source, ty).is_none() + } + + fn get(&self, node: Node) -> impl Iterator + '_ { + match self.0.get(&node) { + Some(x) => Either::Left(x.iter().map(|(&w, t)| (w, t))), + None => Either::Right(iter::empty()) + } + } + + delegate! { + to self.0 { + fn keys(&self) -> impl Iterator; + } + } +} + +#[derive(Debug, Clone)] +struct BBNeedsSourcesMapBuilder { + hugr: H, + needs_sources: BBNeedsSourcesMap, +} + +impl BBNeedsSourcesMapBuilder { + fn new(hugr: H) -> Self { + Self { + hugr, + needs_sources:Default::default(), } - any } - fn finish(mut self, hugr: impl HugrView) -> BBNeedsSourcesMap { + fn insert(&mut self, mut parent: Node, source: Wire, ty: Type) { + let source_parent = self.hugr.get_parent(source.node()).unwrap(); + loop { + if source_parent == parent { + break; + } + if !self.needs_sources.insert(parent, source, ty.clone()) { + break; + } + let Some(parent_of_parent) = self.hugr.get_parent(parent) else { + break; + }; + parent = parent_of_parent + } + } + + fn finish(mut self) -> BBNeedsSourcesMap { { - let conds = self - .0 - .keys() - .copied() - .filter(|&n| hugr.get_optype(n).is_conditional()) - .collect_vec(); - for cond in conds { - let cases = hugr - .children(cond) - .filter(|&child| hugr.get_optype(child).is_case()) - .collect_vec(); - let all_needed: BTreeMap<_, _> = cases - .iter() - .flat_map(|&case| { - let case_needed = self.0.get(&case); - case_needed - .into_iter() - .flat_map(|m| m.iter().map(|(&w, ty)| (w, ty.clone()))) - }) - .collect(); - for case in cases { - let _ = self.0.insert(case, all_needed.clone()); + let conds = self.needs_sources.keys().copied().filter(|&n| self.hugr.get_optype(n).is_conditional()).collect_vec(); + for n in conds { + let n_needs = self.needs_sources.get(n).map(|(w,ty)| (w, ty.clone())).collect_vec(); + for case in self.hugr + .children(n) + .filter(|&child| self.hugr.get_optype(child).is_case()) { + for (w, ty) in n_needs.iter() { + self.needs_sources.insert(case, *w, ty.clone()); + } } } } { - let cfgs = self - .0 - .keys() - .copied() - .filter(|&n| hugr.get_optype(n).is_cfg() && self.0.contains_key(&n)) - .collect_vec(); - for cfg in cfgs { - let dfbs = hugr - .children(cfg) - .filter(|&child| hugr.get_optype(child).is_dataflow_block()) + let cfgs = self.needs_sources.keys().copied().filter(|&n| self.hugr.get_optype(n).is_cfg()).collect_vec(); + for n in cfgs { + let dfbs = self.hugr + .children(n) + .filter(|&child| self.hugr.get_optype(child).is_dataflow_block()) .collect_vec(); loop { let mut any_change = false; for &dfb in dfbs.iter() { - for succ_n in hugr.output_neighbours(dfb) { - for (w, ty) in self.0.get(&succ_n).cloned().unwrap_or_default() { - any_change |= - self.0.entry(dfb).or_default().insert(w, ty).is_none(); + for succ_n in self.hugr.output_neighbours(dfb) { + for (w, ty) in self.needs_sources.get(succ_n).map(|(w,ty)| (w, ty.clone())).collect_vec() { + any_change |= self.needs_sources.insert(dfb, w, ty.clone()); } } } @@ -545,20 +582,34 @@ impl BBNeedsSourcesMapBuilder { } } } + self.needs_sources + } +} - self.0 +fn build_needs_sources_map(hugr: impl HugrView, nonlocal_edges: &HashMap) -> BBNeedsSourcesMap { + let mut bnsm = BBNeedsSourcesMapBuilder::new(&hugr); + for workitem in nonlocal_edges.values() { + let parent = hugr.get_parent(workitem.target.0).unwrap(); + debug_assert!(hugr.get_parent(parent).is_some()); + bnsm.insert(parent, workitem.source, workitem.ty.clone()); } + bnsm.finish() } pub fn remove_nonlocal_edges( hugr: &mut impl HugrMut, root: Node, ) -> Result<(), NonLocalEdgesError> { + // First we collect all the non-local edges in the graph. We associate them to a WorkItem, which tracks: + // * the source of the non-local edge + // * the target of the non-local edge + // * the type of the non-local edge. Note that all non-local edges are + // value edges, so the type is well defined. let nonlocal_edges_map: HashMap = nonlocal_edges(&DescendantsGraph::::try_new(hugr, root)?) - .map(|target @ (node, inport)| { + .filter_map(|target @ (node, inport)| { let source = { - let (n, p) = hugr.single_linked_output(node, inport).unwrap(); + let (n, p) = hugr.single_linked_output(node, inport)?; Wire::new(n, p) }; debug_assert!( @@ -569,7 +620,7 @@ pub fn remove_nonlocal_edges( else { panic!("impossible") }; - (node, WorkItem { source, target, ty }) + Some((node, WorkItem { source, target, ty })) }) .collect(); @@ -577,45 +628,12 @@ pub fn remove_nonlocal_edges( return Ok(()); } - let bb_needs_sources_map = { - let nonlocal_sorted = { - let mut v = iter::successors(Some(vec![root]), |nodes| { - let children = nodes.iter().flat_map(|&n| hugr.children(n)).collect_vec(); - (!children.is_empty()).then_some(children) - }) - .flatten() - .filter_map(|n| nonlocal_edges_map.get(&n)) - .collect_vec(); - v.reverse(); - v - }; - let mut parent_set = HashSet::::new(); - // earlier items are deeper in the heirarchy - let mut parent_worklist = VecDeque::::new(); - let mut add_parent = |p, wl: &mut VecDeque<_>| { - if parent_set.insert(p) { - wl.push_back(p); - } - }; - let mut bnsm = BBNeedsSourcesMapBuilder::default(); - for workitem in nonlocal_sorted { - let parent = hugr.get_parent(workitem.target.0).unwrap(); - debug_assert!(hugr.get_parent(parent).is_some()); - bnsm.insert(parent, workitem.source, workitem.ty.clone()); - add_parent(parent, &mut parent_worklist); - } - - while let Some(bb_node) = parent_worklist.pop_front() { - let Some(parent) = hugr.get_parent(bb_node) else { - continue; - }; - if bnsm.extend_parent_needs_for(&hugr, bb_node) { - add_parent(parent, &mut parent_worklist); - } - } - bnsm.finish(&hugr) - }; + // We now compute the sources needed by each parent node. + // For a given non-local edge every intermediate node in the hierarchy + // between the source's parent and the target needs that source. + let bb_needs_sources_map = build_needs_sources_map(&hugr, &nonlocal_edges_map); + // TODO move this out-of-line #[cfg(debug_assertions)] { for (&n, wi) in nonlocal_edges_map.iter() { @@ -625,7 +643,7 @@ pub fn remove_nonlocal_edges( if hugr.get_parent(wi.source.node()).unwrap() == parent { break; } - assert!(bb_needs_sources_map[&parent].contains_key(&wi.source)); + assert!(bb_needs_sources_map.get(parent).find(|(w,_)| *w == wi.source).is_some()); m = parent; } } @@ -635,6 +653,11 @@ pub fn remove_nonlocal_edges( } } + // Here we mutate the HUGR; adding ports to parent nodes and their Input nodes. + // The result is: + // * parent_source_map: A map from parent and source to the wire that should substitute for that source in that parent. + // * worklist: a list of workitems. Each should be fulfilled by connecting the source, substituted through parent_source_map, to the target. + // * control_worklist: A list of control ports (i.e. 0th output port of DataflowBlock or TailLoop) that must be rewired. let (parent_source_map, worklist, control_worklist) = { let mut worklist = nonlocal_edges_map.into_values().collect_vec(); let (wl, psm, control_worklist) = thread_sources(hugr, &bb_needs_sources_map); @@ -643,17 +666,7 @@ pub fn remove_nonlocal_edges( }; for wi in worklist { - let parent = hugr.get_parent(wi.target.0).unwrap(); - let source = if hugr.get_parent(wi.source.node()).unwrap() == parent { - wi.source - } else { - parent_source_map - .get_source_in_parent(parent, wi.source, &hugr) - .0 - }; - debug_assert_eq!(hugr.get_parent(source.node()), hugr.get_parent(wi.target.0)); - hugr.disconnect(wi.target.0, wi.target.1); - hugr.connect(source.node(), source.source(), wi.target.0, wi.target.1); + wi.go(hugr, &parent_source_map) } for cwi in control_worklist { From 040f5dbd15f449f5767f0d1d105b7b8a8694a667 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 19 May 2025 12:14:44 +0100 Subject: [PATCH 24/61] comments, reorder analysis before transform, add tests of vec_insert --- hugr-passes/src/non_local.rs | 335 +++++++++++++++++++---------------- 1 file changed, 179 insertions(+), 156 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 5d320bae89..1c6bfece2b 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -61,6 +61,169 @@ pub fn nonlocal_edges(hugr: &H) -> impl Iterator(BTreeMap, Type>>); + +impl Default for BBNeedsSourcesMap { + fn default() -> Self { + Self(BTreeMap::default()) + } +} + +struct NeedsSourcesMapIter<'a, N>( + <&'a BTreeMap, Type>> as IntoIterator>::IntoIter, +); + +impl<'a, N> Iterator for NeedsSourcesMapIter<'a, N> { + type Item = (&'a N, &'a BTreeMap, Type>); + + fn next(&mut self) -> Option { + self.0.next() + } +} + +impl<'a, N: HugrNode> IntoIterator for &'a BBNeedsSourcesMap { + type Item = ::Item; + type IntoIter = NeedsSourcesMapIter<'a, N>; + + fn into_iter(self) -> Self::IntoIter { + NeedsSourcesMapIter(self.0.iter()) + } +} + +impl BBNeedsSourcesMap { + fn insert(&mut self, node: N, source: Wire, ty: Type) -> bool { + self.0.entry(node).or_default().insert(source, ty).is_none() + } + + fn get(&self, node: N) -> impl Iterator, &Type)> + '_ { + match self.0.get(&node) { + Some(x) => Either::Left(x.iter()), + None => Either::Right(iter::empty()), + } + } + + delegate! { + to self.0 { + fn keys(&self) -> impl Iterator; + } + } +} + +#[derive(Debug, Clone)] +struct BBNeedsSourcesMapBuilder { + hugr: H, + needs_sources: BBNeedsSourcesMap, +} + +impl BBNeedsSourcesMapBuilder { + fn new(hugr: H) -> Self { + Self { + hugr, + needs_sources: Default::default(), + } + } + + fn insert(&mut self, mut parent: H::Node, source: Wire, ty: Type) { + let source_parent = self.hugr.get_parent(source.node()).unwrap(); + loop { + if source_parent == parent { + break; + } + if !self.needs_sources.insert(parent, source, ty.clone()) { + break; + } + let Some(parent_of_parent) = self.hugr.get_parent(parent) else { + break; + }; + parent = parent_of_parent + } + } + + fn finish(mut self) -> BBNeedsSourcesMap { + { + // Conditionals. Any `Case` needing an input, means the parent Conditional needs it too. + let conds = self + .needs_sources + .keys() + .copied() + .filter(|&n| self.hugr.get_optype(n).is_conditional()) + .collect_vec(); + for n in conds { + let n_needs = self + .needs_sources + .get(n) + .map(|(&w, ty)| (w, ty.clone())) + .collect_vec(); + for case in self + .hugr + .children(n) + .filter(|&child| self.hugr.get_optype(child).is_case()) + { + for (w, ty) in n_needs.iter() { + self.needs_sources.insert(case, *w, ty.clone()); + } + } + } + } + { + let cfgs = self + .needs_sources + .keys() + .copied() + .filter(|&n| self.hugr.get_optype(n).is_cfg()) + .collect_vec(); + for n in cfgs { + let dfbs = self + .hugr + .children(n) + .filter(|&child| self.hugr.get_optype(child).is_dataflow_block()) + .collect_vec(); + loop { + let mut any_change = false; + for &dfb in dfbs.iter() { + for succ_n in self.hugr.output_neighbours(dfb) { + for (w, ty) in self + .needs_sources + .get(succ_n) + .map(|(w, ty)| (*w, ty.clone())) + .collect_vec() + { + // Do we need something like: + // if w.node() == dfb: continue + any_change |= self.needs_sources.insert(dfb, w, ty.clone()); + } + } + } + if !any_change { + break; + } + } + } + } + self.needs_sources + } +} + +// Identify all required extra inputs (for both Dom and Ext edges) +fn build_needs_sources_map( + hugr: impl HugrView, + nonlocal_edges: &HashMap>, +) -> BBNeedsSourcesMap { + let mut bnsm = BBNeedsSourcesMapBuilder::new(&hugr); + for workitem in nonlocal_edges.values() { + let parent = hugr.get_parent(workitem.target.0).unwrap(); + debug_assert!(hugr.get_parent(parent).is_some()); + bnsm.insert(parent, workitem.source, workitem.ty.clone()); + } + bnsm.finish() +} + +// Transformation: adding extra ports, and wiring them up =============================== + #[derive(derive_more::Error, derive_more::From, derive_more::Display, Debug, PartialEq)] #[non_exhaustive] pub enum NonLocalEdgesError { @@ -177,8 +340,8 @@ impl ParentSourceMap { #[derive(Clone, Debug)] struct ControlWorkItem { - output_node: N, - variant_source_prefixes: Vec>>, + output_node: N, // Output node of CFG / TailLoop + variant_source_prefixes: Vec>>, // prefixes to each element of Sum type } impl ControlWorkItem { @@ -512,160 +675,6 @@ fn mk_workitems( }) } -#[derive(Debug, Clone)] -struct BBNeedsSourcesMap(BTreeMap, Type>>); - -impl Default for BBNeedsSourcesMap { - fn default() -> Self { - Self(BTreeMap::default()) - } -} - -struct NeedsSourcesMapIter<'a, N>( - <&'a BTreeMap, Type>> as IntoIterator>::IntoIter, -); - -impl<'a, N> Iterator for NeedsSourcesMapIter<'a, N> { - type Item = (&'a N, &'a BTreeMap, Type>); - - fn next(&mut self) -> Option { - self.0.next() - } -} - -impl<'a, N: HugrNode> IntoIterator for &'a BBNeedsSourcesMap { - type Item = ::Item; - type IntoIter = NeedsSourcesMapIter<'a, N>; - - fn into_iter(self) -> Self::IntoIter { - NeedsSourcesMapIter(self.0.iter()) - } -} - -impl BBNeedsSourcesMap { - fn insert(&mut self, node: N, source: Wire, ty: Type) -> bool { - self.0.entry(node).or_default().insert(source, ty).is_none() - } - - fn get(&self, node: N) -> impl Iterator, &Type)> + '_ { - match self.0.get(&node) { - Some(x) => Either::Left(x.iter()), - None => Either::Right(iter::empty()), - } - } - - delegate! { - to self.0 { - fn keys(&self) -> impl Iterator; - } - } -} - -#[derive(Debug, Clone)] -struct BBNeedsSourcesMapBuilder { - hugr: H, - needs_sources: BBNeedsSourcesMap, -} - -impl BBNeedsSourcesMapBuilder { - fn new(hugr: H) -> Self { - Self { - hugr, - needs_sources: Default::default(), - } - } - - fn insert(&mut self, mut parent: H::Node, source: Wire, ty: Type) { - let source_parent = self.hugr.get_parent(source.node()).unwrap(); - loop { - if source_parent == parent { - break; - } - if !self.needs_sources.insert(parent, source, ty.clone()) { - break; - } - let Some(parent_of_parent) = self.hugr.get_parent(parent) else { - break; - }; - parent = parent_of_parent - } - } - - fn finish(mut self) -> BBNeedsSourcesMap { - { - let conds = self - .needs_sources - .keys() - .copied() - .filter(|&n| self.hugr.get_optype(n).is_conditional()) - .collect_vec(); - for n in conds { - let n_needs = self - .needs_sources - .get(n) - .map(|(&w, ty)| (w, ty.clone())) - .collect_vec(); - for case in self - .hugr - .children(n) - .filter(|&child| self.hugr.get_optype(child).is_case()) - { - for (w, ty) in n_needs.iter() { - self.needs_sources.insert(case, *w, ty.clone()); - } - } - } - } - { - let cfgs = self - .needs_sources - .keys() - .copied() - .filter(|&n| self.hugr.get_optype(n).is_cfg()) - .collect_vec(); - for n in cfgs { - let dfbs = self - .hugr - .children(n) - .filter(|&child| self.hugr.get_optype(child).is_dataflow_block()) - .collect_vec(); - loop { - let mut any_change = false; - for &dfb in dfbs.iter() { - for succ_n in self.hugr.output_neighbours(dfb) { - for (w, ty) in self - .needs_sources - .get(succ_n) - .map(|(w, ty)| (*w, ty.clone())) - .collect_vec() - { - any_change |= self.needs_sources.insert(dfb, w, ty.clone()); - } - } - } - if !any_change { - break; - } - } - } - } - self.needs_sources - } -} - -fn build_needs_sources_map( - hugr: impl HugrView, - nonlocal_edges: &HashMap>, -) -> BBNeedsSourcesMap { - let mut bnsm = BBNeedsSourcesMapBuilder::new(&hugr); - for workitem in nonlocal_edges.values() { - let parent = hugr.get_parent(workitem.target.0).unwrap(); - debug_assert!(hugr.get_parent(parent).is_some()); - bnsm.insert(parent, workitem.source, workitem.ty.clone()); - } - bnsm.finish() -} - pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), NonLocalEdgesError> { // First we collect all the non-local edges in the graph. We associate them to a WorkItem, which tracks: // * the source of the non-local edge @@ -768,6 +777,20 @@ mod test { use super::*; + #[test] + fn vec_insert0() { + let mut v = vec![5,7,9]; + vec_insert(&mut v, [1,2], 0); + assert_eq!(v, [1,2,5,7,9]); + } + + #[test] + fn vec_insert1() { + let mut v = vec![5,7,9]; + vec_insert(&mut v, [1,2], 1); + assert_eq!(v, [5,1,2,7,9]); + } + #[test] fn ensures_no_nonlocal_edges() { let hugr = { From bd68a4b55d22e8746ed9ace55d8fd3a1e111e33e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 19 May 2025 12:29:15 +0100 Subject: [PATCH 25/61] simplify nonlocal_edges: inports can have at most one connected outport --- hugr-passes/src/non_local.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 1c6bfece2b..d26e40fc6a 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -53,10 +53,8 @@ use hugr_core::{ pub fn nonlocal_edges(hugr: &H) -> impl Iterator + '_ { hugr.entry_descendants().flat_map(move |node| { hugr.in_value_types(node).filter_map(move |(in_p, _)| { - let parent = hugr.get_parent(node); - hugr.linked_outputs(node, in_p) - .any(|(neighbour_node, _)| parent != hugr.get_parent(neighbour_node)) - .then_some((node, in_p)) + let (src, _) = hugr.single_linked_output(node, in_p)?; + (hugr.get_parent(node) != hugr.get_parent(src)).then_some((node, in_p)) }) }) } From ad1714c7b2fbfeed7f798d35b3fc97b89add7f8c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 19 May 2025 12:32:12 +0100 Subject: [PATCH 26/61] Remove NeedsSourcesMapIter/impl IntoIterator for &BBNeedsSourcesMap --- hugr-passes/src/non_local.rs | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index d26e40fc6a..511ac7798b 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -71,27 +71,6 @@ impl Default for BBNeedsSourcesMap { } } -struct NeedsSourcesMapIter<'a, N>( - <&'a BTreeMap, Type>> as IntoIterator>::IntoIter, -); - -impl<'a, N> Iterator for NeedsSourcesMapIter<'a, N> { - type Item = (&'a N, &'a BTreeMap, Type>); - - fn next(&mut self) -> Option { - self.0.next() - } -} - -impl<'a, N: HugrNode> IntoIterator for &'a BBNeedsSourcesMap { - type Item = ::Item; - type IntoIter = NeedsSourcesMapIter<'a, N>; - - fn into_iter(self) -> Self::IntoIter { - NeedsSourcesMapIter(self.0.iter()) - } -} - impl BBNeedsSourcesMap { fn insert(&mut self, node: N, source: Wire, ty: Type) -> bool { self.0.entry(node).or_default().insert(source, ty).is_none() @@ -645,7 +624,7 @@ fn thread_sources( Vec>, ) { let mut state = ThreadState::new(bb_needs_sources_map); - for (&bb, sources) in bb_needs_sources_map { + for (&bb, sources) in bb_needs_sources_map.0.iter() { let sources = sources.iter().map(|(&w, ty)| (w, ty.clone())).collect_vec(); match hugr.get_optype(bb).clone() { OpType::DFG(_) => state.do_dfg(hugr, bb, sources), From f94c6ad1856891802afd58593dd580b2375a39fd Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 19 May 2025 13:53:51 +0100 Subject: [PATCH 27/61] BBNeedsSourcesMap: work harder in insert(), remove finish --- hugr-passes/src/non_local.rs | 103 ++++++++++------------------------- 1 file changed, 29 insertions(+), 74 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 511ac7798b..387516e9ad 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -106,81 +106,36 @@ impl BBNeedsSourcesMapBuilder { fn insert(&mut self, mut parent: H::Node, source: Wire, ty: Type) { let source_parent = self.hugr.get_parent(source.node()).unwrap(); - loop { - if source_parent == parent { - break; - } + while source_parent != parent { if !self.needs_sources.insert(parent, source, ty.clone()) { break; } - let Some(parent_of_parent) = self.hugr.get_parent(parent) else { - break; - }; - parent = parent_of_parent - } - } - - fn finish(mut self) -> BBNeedsSourcesMap { - { - // Conditionals. Any `Case` needing an input, means the parent Conditional needs it too. - let conds = self - .needs_sources - .keys() - .copied() - .filter(|&n| self.hugr.get_optype(n).is_conditional()) - .collect_vec(); - for n in conds { - let n_needs = self - .needs_sources - .get(n) - .map(|(&w, ty)| (w, ty.clone())) - .collect_vec(); - for case in self - .hugr - .children(n) - .filter(|&child| self.hugr.get_optype(child).is_case()) - { - for (w, ty) in n_needs.iter() { - self.needs_sources.insert(case, *w, ty.clone()); - } + if self.hugr.get_optype(parent).is_conditional() { + // One of these we must have just done on the previous iteration + for case in self.hugr.children(parent) { + // Full recursion unnecessary as we've just added parent: + self.needs_sources.insert(case, source, ty.clone()); } } - } - { - let cfgs = self - .needs_sources - .keys() - .copied() - .filter(|&n| self.hugr.get_optype(n).is_cfg()) - .collect_vec(); - for n in cfgs { - let dfbs = self - .hugr - .children(n) - .filter(|&child| self.hugr.get_optype(child).is_dataflow_block()) - .collect_vec(); - loop { - let mut any_change = false; - for &dfb in dfbs.iter() { - for succ_n in self.hugr.output_neighbours(dfb) { - for (w, ty) in self - .needs_sources - .get(succ_n) - .map(|(w, ty)| (*w, ty.clone())) - .collect_vec() - { - // Do we need something like: - // if w.node() == dfb: continue - any_change |= self.needs_sources.insert(dfb, w, ty.clone()); - } - } - } - if !any_change { - break; - } + // this will panic if source_parent is not an ancestor of target + let parent_parent = self.hugr.get_parent(parent).unwrap(); + if self.hugr.get_optype(parent).is_dataflow_block() { + assert!(self.hugr.get_optype(parent_parent).is_cfg()); + for pred in self.hugr.input_neighbours(parent).collect::>() { + self.insert(pred, source, ty.clone()); } + if Some(parent) != self.hugr.children(parent_parent).next() { + // Recursive calls on predecessors will have traced back to entry block + // (or source_parent itself if a dominating Basic Block) + break; + } + // We've just added to entry node - so must add to CFG as well } + parent = parent_parent; } + } + + fn finish(self) -> BBNeedsSourcesMap { self.needs_sources } } @@ -317,7 +272,7 @@ impl ParentSourceMap { #[derive(Clone, Debug)] struct ControlWorkItem { - output_node: N, // Output node of CFG / TailLoop + output_node: N, // Output node of CFG / TailLoop variant_source_prefixes: Vec>>, // prefixes to each element of Sum type } @@ -756,16 +711,16 @@ mod test { #[test] fn vec_insert0() { - let mut v = vec![5,7,9]; - vec_insert(&mut v, [1,2], 0); - assert_eq!(v, [1,2,5,7,9]); + let mut v = vec![5, 7, 9]; + vec_insert(&mut v, [1, 2], 0); + assert_eq!(v, [1, 2, 5, 7, 9]); } #[test] fn vec_insert1() { - let mut v = vec![5,7,9]; - vec_insert(&mut v, [1,2], 1); - assert_eq!(v, [5,1,2,7,9]); + let mut v = vec![5, 7, 9]; + vec_insert(&mut v, [1, 2], 1); + assert_eq!(v, [5, 1, 2, 7, 9]); } #[test] From 6ea566def241c1be4c07b2b0ae743efcc3ef1191 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 19 May 2025 18:02:38 +0100 Subject: [PATCH 28/61] New alg, (re/ab)using ControlWorkItem, but need ParentSourceMap --- hugr-passes/src/non_local.rs | 175 +++++++++++++++++++++++++++++++++++ 1 file changed, 175 insertions(+) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 387516e9ad..eedd76c41c 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -155,6 +155,181 @@ fn build_needs_sources_map( } // Transformation: adding extra ports, and wiring them up =============================== +impl BBNeedsSourcesMap { + fn thread_node( + &self, + hugr: &mut impl HugrMut, + node: N, + locals: &HashMap, Wire>, + ) { + if self.get(node).next().is_none() { + // No edges incoming into this subtree, but there could still be nonlocal edges internal to it + for ch in hugr.children(node).collect::>() { + self.thread_node(hugr, ch, &HashMap::new()) + } + return; + } + + let sources: Vec<(Wire, Type)> = self.get(node).map(|(w, t)| (*w, t.clone())).collect(); + let src_wires: Vec> = sources.iter().map(|(w, _)| *w).collect(); + + // `match` must deal with everything inside the node, and update the signature (per OpType) + let start_new_port_index = match hugr.optype_mut(node) { + OpType::DFG(dfg) => { + let ins = dfg.signature.input.to_mut(); + let start_new_port_index = ins.len(); + ins.extend(sources.iter().map(|(_, t)| t.clone())); + + self.thread_dataflow_parent(hugr, node, start_new_port_index, sources); + start_new_port_index + } + OpType::Conditional(cond) => { + let start_new_port_index = cond.signature().input.len(); + cond.other_inputs + .to_mut() + .extend(sources.iter().map(|x| x.1.clone())); + + self.thread_conditional(hugr, node, sources); + start_new_port_index + } + OpType::TailLoop(tail_op) => { + vec_prepend( + tail_op.just_inputs.to_mut(), + sources.iter().map(|(_, t)| t.clone()), + ); + self.thread_tailloop(hugr, node, sources); + 0 + } + OpType::CFG(cfg) => { + vec_prepend( + cfg.signature.input.to_mut(), + sources.iter().map(|(_, t)| t.clone()), + ); + assert_eq!( + self.get(node).collect::>(), + self.get(hugr.children(node).next().unwrap()) + .collect::>() + ); // Entry node + for bb in hugr.children(node).collect::>() { + self.thread_bb(hugr, bb); + } + 0 + } + _ => panic!( + "All containers handled except Module/FuncDefn or root Case/DFB, which should not have incoming nonlocal edges" + ), + }; + + let new_dfg_ports = hugr.insert_ports( + node, + Direction::Incoming, + start_new_port_index, + src_wires.len(), + ); + let local_srcs = src_wires.into_iter().map(|w| *locals.get(&w).unwrap_or(&w)); + for (w, tgt_port) in local_srcs.zip_eq(new_dfg_ports) { + assert_eq!(hugr.get_parent(w.node()), hugr.get_parent(node)); + hugr.connect(w.node(), w.source(), node, tgt_port) + } + } + + fn thread_dataflow_parent( + &self, + hugr: &mut impl HugrMut, + node: N, + start_new_port_index: usize, + srcs: Vec<(Wire, Type)>, + ) -> HashMap, Wire> { + let nlocals = if srcs.is_empty() { + HashMap::new() + } else { + let (srcs, tys): (Vec<_>, Vec) = srcs.into_iter().unzip(); + let [inp, _] = hugr.get_io(node).unwrap(); + let OpType::Input(in_op) = hugr.optype_mut(inp) else { + panic!("Expected Input node") + }; + vec_insert(in_op.types.to_mut(), tys, start_new_port_index); + let new_outports = + hugr.insert_ports(inp, Direction::Outgoing, start_new_port_index, srcs.len()); + + srcs.into_iter() + .zip_eq(new_outports) + .map(|(w, p)| (w, Wire::new(inp, p))) + .collect() + }; + for ch in hugr.children(node).collect::>() { + self.thread_node(hugr, ch, &nlocals); + } + nlocals + } + + fn thread_conditional( + &self, + hugr: &mut impl HugrMut, + node: N, + srcs: Vec<(Wire, Type)>, + ) { + for case in hugr.children(node).collect::>() { + let OpType::Case(case_op) = hugr.optype_mut(case) else { + continue; + }; + let ins = case_op.signature.input.to_mut(); + let start_case_port_index = ins.len(); + ins.extend(srcs.iter().map(|(_, t)| t.clone())); + self.thread_dataflow_parent(hugr, case, start_case_port_index, srcs.clone()); + } + } + + fn thread_tailloop( + &self, + hugr: &mut impl HugrMut, + node: N, + srcs: Vec<(Wire, Type)>, + ) { + let [_, o] = hugr.get_io(node).unwrap(); + let new_sum_row_prefixes = { + let mut v = vec![vec![]; 2]; + v[TailLoop::CONTINUE_TAG].extend(srcs.iter().map(|(w, _)| w)); + v + }; + ControlWorkItem { + output_node: o, + variant_source_prefixes: new_sum_row_prefixes, + } + .go(&mut hugr, psm); + self.thread_dataflow_parent(hugr, node, 0, srcs); + } + + fn thread_bb(&self, hugr: &mut impl HugrMut, node: N) { + let locals = self.thread_dataflow_parent( + hugr, + node, + 0, + self.get(node).map(|(w, t)| (*w, t.clone())).collect(), + ); + let [_, output_node] = hugr.get_io(node).unwrap(); + let variant_source_prefixes = hugr + .output_neighbours(node) + .map(|succ| { + // The wires required for each successor block, should be available in the predecessor + self.get(succ) + .map(|(w, _)| { + if hugr.get_parent(w.node()) == Some(node) { + *w + } else { + *locals.get(w).unwrap() + } + }) + .collect() + }) + .collect(); + ControlWorkItem { + output_node, + variant_source_prefixes, + } + .go(&mut hugr, psm) + } +} #[derive(derive_more::Error, derive_more::From, derive_more::Display, Debug, PartialEq)] #[non_exhaustive] From bcadbca33918613f9ad233ab26ad3c86012da7eb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 19 May 2025 18:07:37 +0100 Subject: [PATCH 29/61] TEMP change ControlWorkItem::go to take impl Into> --- hugr-passes/src/non_local.rs | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index eedd76c41c..64818a92e8 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -296,7 +296,7 @@ impl BBNeedsSourcesMap { output_node: o, variant_source_prefixes: new_sum_row_prefixes, } - .go(&mut hugr, psm); + .go(hugr, None); self.thread_dataflow_parent(hugr, node, 0, srcs); } @@ -327,7 +327,7 @@ impl BBNeedsSourcesMap { output_node, variant_source_prefixes, } - .go(&mut hugr, psm) + .go(hugr, None) } } @@ -452,7 +452,14 @@ struct ControlWorkItem { } impl ControlWorkItem { - fn go(self, hugr: &mut impl HugrMut, psm: &ParentSourceMap) { + fn go<'a>( + self, + hugr: &mut impl HugrMut, + psm: impl Into>>, + ) where + N: 'a, + { + let psm = psm.into(); let parent = hugr.get_parent(self.output_node).unwrap(); let Some(mut output) = hugr.get_optype(self.output_node).as_output().cloned() else { panic!("impossible") @@ -470,7 +477,17 @@ impl ControlWorkItem { }; let mut type_for_source = |source: &Wire| { - let (w, t) = psm.get_source_in_parent(parent, *source, &hugr); + let (w, t) = match psm { + Some(psm) => psm.get_source_in_parent(parent, *source, &hugr), + None => ( + *source, + hugr.signature(source.node()) + .unwrap() + .out_port_type(source.source()) + .unwrap() + .clone(), + ), + }; let replaced = needed_sources.insert(*source, (w, t.clone())); debug_assert!(!replaced.is_some_and(|x| x != (w, t.clone()))); t From 2b7207025e658dc7658db68b1d55f8de8413f64a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 19 May 2025 18:45:12 +0100 Subject: [PATCH 30/61] Use new code; some fixes; still need to update BB sum_rows --- hugr-passes/src/non_local.rs | 35 ++++++++++++++--------------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 64818a92e8..1aafb1f58e 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -211,7 +211,9 @@ impl BBNeedsSourcesMap { .collect::>() ); // Entry node for bb in hugr.children(node).collect::>() { - self.thread_bb(hugr, bb); + if hugr.get_optype(bb).is_dataflow_block() { + self.thread_bb(hugr, bb); + } } 0 } @@ -258,6 +260,15 @@ impl BBNeedsSourcesMap { .collect() }; for ch in hugr.children(node).collect::>() { + for (inp, _) in hugr.in_value_types(ch).collect::>() { + if let Some((src_n, src_p)) = hugr.single_linked_output(ch, inp) { + if hugr.get_parent(src_n) != Some(node) { + hugr.disconnect(ch, inp); + let new_p = nlocals.get(&Wire::new(src_n, src_p)).unwrap(); + hugr.connect(new_p.node(), new_p.source(), ch, inp); + } + } + } self.thread_node(hugr, ch, &nlocals); } nlocals @@ -307,7 +318,6 @@ impl BBNeedsSourcesMap { 0, self.get(node).map(|(w, t)| (*w, t.clone())).collect(), ); - let [_, output_node] = hugr.get_io(node).unwrap(); let variant_source_prefixes = hugr .output_neighbours(node) .map(|succ| { @@ -323,6 +333,7 @@ impl BBNeedsSourcesMap { .collect() }) .collect(); + let [_, output_node] = hugr.get_io(node).unwrap(); ControlWorkItem { output_node, variant_source_prefixes, @@ -856,25 +867,7 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), NonLocalEdg } } - // Here we mutate the HUGR; adding ports to parent nodes and their Input nodes. - // The result is: - // * parent_source_map: A map from parent and source to the wire that should substitute for that source in that parent. - // * worklist: a list of workitems. Each should be fulfilled by connecting the source, substituted through parent_source_map, to the target. - // * control_worklist: A list of control ports (i.e. 0th output port of DataflowBlock or TailLoop) that must be rewired. - let (parent_source_map, worklist, control_worklist) = { - let mut worklist = nonlocal_edges_map.into_values().collect_vec(); - let (wl, psm, control_worklist) = thread_sources(hugr, &bb_needs_sources_map); - worklist.extend(wl); - (psm, worklist, control_worklist) - }; - - for wi in worklist { - wi.go(hugr, &parent_source_map) - } - - for cwi in control_worklist { - cwi.go(hugr, &parent_source_map) - } + bb_needs_sources_map.thread_node(hugr, hugr.entrypoint(), &HashMap::new()); Ok(()) } From c1b07b8b57236742f5d604812062422f14d4d2f1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 19 May 2025 18:52:37 +0100 Subject: [PATCH 31/61] Remove ParentSourceMap, WorkItem (undoing the TEMP) --- hugr-passes/src/non_local.rs | 357 +---------------------------------- 1 file changed, 7 insertions(+), 350 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 1aafb1f58e..921dd104bd 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -307,7 +307,7 @@ impl BBNeedsSourcesMap { output_node: o, variant_source_prefixes: new_sum_row_prefixes, } - .go(hugr, None); + .go(hugr); self.thread_dataflow_parent(hugr, node, 0, srcs); } @@ -338,7 +338,7 @@ impl BBNeedsSourcesMap { output_node, variant_source_prefixes, } - .go(hugr, None) + .go(hugr) } } @@ -369,93 +369,6 @@ struct WorkItem { ty: Type, } -impl WorkItem { - pub fn go(self, hugr: &mut impl HugrMut, parent_source_map: &ParentSourceMap) { - let parent = hugr.get_parent(self.target.0).unwrap(); - let source = if hugr.get_parent(self.source.node()).unwrap() == parent { - self.source - } else { - parent_source_map - .get_source_in_parent(parent, self.source, &hugr) - .0 - }; - debug_assert_eq!( - hugr.get_parent(source.node()), - hugr.get_parent(self.target.0) - ); - hugr.disconnect(self.target.0, self.target.1); - hugr.connect(source.node(), source.source(), self.target.0, self.target.1); - } -} - -#[derive(Clone, Debug)] -struct ParentSourceMap(BTreeMap, (Wire, Type)>>); - -impl Default for ParentSourceMap { - fn default() -> Self { - Self(BTreeMap::default()) - } -} - -impl ParentSourceMap { - fn insert_sources_in_parent( - &mut self, - parent: N, - sources: impl IntoIterator, Wire, Type)>, - ) { - debug_assert!(!self.0.contains_key(&parent)); - self.0 - .entry(parent) - .or_default() - .extend(sources.into_iter().map(|(s, p, t)| (s, (p, t)))); - } - - fn get_source_in_parent( - &self, - parent: N, - source: Wire, - hugr: impl HugrView, - ) -> (Wire, Type) { - let r @ (w, _) = self - .0 - .get(&parent) - .and_then(|m| m.get(&source).cloned()) - .unwrap(); - debug_assert_eq!(hugr.get_parent(w.node()).unwrap(), parent); - r - } - - fn thread_dataflow_parent( - &mut self, - hugr: &mut impl HugrMut, - parent: N, - start_port_index: usize, - sources: impl IntoIterator, Type)>, - ) { - let (source_wires, source_types): (Vec<_>, Vec<_>) = sources.into_iter().unzip(); - let input_wires = { - let [input_n, _] = hugr.get_io(parent).unwrap(); - let Some(mut input) = hugr.get_optype(input_n).as_input().cloned() else { - panic!("impossible") - }; - vec_insert(input.types.to_mut(), source_types.clone(), start_port_index); - hugr.replace_op(input_n, input); - hugr.insert_ports( - input_n, - Direction::Outgoing, - start_port_index, - source_wires.len(), - ) - .map(move |new_port| Wire::new(input_n, new_port)) - .collect_vec() - }; - self.insert_sources_in_parent( - parent, - itertools::izip!(source_wires, input_wires, source_types), - ); - } -} - #[derive(Clone, Debug)] struct ControlWorkItem { output_node: N, // Output node of CFG / TailLoop @@ -463,14 +376,7 @@ struct ControlWorkItem { } impl ControlWorkItem { - fn go<'a>( - self, - hugr: &mut impl HugrMut, - psm: impl Into>>, - ) where - N: 'a, - { - let psm = psm.into(); + fn go(self, hugr: &mut impl HugrMut) { let parent = hugr.get_parent(self.output_node).unwrap(); let Some(mut output) = hugr.get_optype(self.output_node).as_output().cloned() else { panic!("impossible") @@ -488,19 +394,13 @@ impl ControlWorkItem { }; let mut type_for_source = |source: &Wire| { - let (w, t) = match psm { - Some(psm) => psm.get_source_in_parent(parent, *source, &hugr), - None => ( - *source, - hugr.signature(source.node()) + let t = hugr.signature(source.node()) .unwrap() .out_port_type(source.source()) .unwrap() - .clone(), - ), - }; - let replaced = needed_sources.insert(*source, (w, t.clone())); - debug_assert!(!replaced.is_some_and(|x| x != (w, t.clone()))); + .clone(); + let replaced = needed_sources.insert(*source, (*source, t.clone())); + debug_assert!(!replaced.is_some_and(|x| x != (*source, t.clone()))); t }; let old_sum_rows: Vec = sum_type @@ -567,249 +467,6 @@ impl ControlWorkItem { } } -#[derive(Clone, Debug)] -struct ThreadState<'a, N: HugrNode> { - parent_source_map: ParentSourceMap, - needs: &'a BBNeedsSourcesMap, - worklist: Vec>, - control_worklist: Vec>, -} - -impl<'a, N: HugrNode> ThreadState<'a, N> { - delegate! { - to self.parent_source_map { - fn thread_dataflow_parent( - &mut self, - hugr: &mut impl HugrMut, - parent: N, - start_port_index: usize, - sources: impl IntoIterator, Type)>, - ); - } - } - - fn new(bbnsm: &'a BBNeedsSourcesMap) -> Self { - Self { - parent_source_map: ParentSourceMap::default(), - needs: bbnsm, - worklist: vec![], - control_worklist: vec![], - } - } - - fn do_dataflow_block( - &mut self, - hugr: &mut impl HugrMut, - node: N, - sources: Vec<(Wire, Type)>, - ) { - let types = sources.iter().map(|x| x.1.clone()).collect_vec(); - let new_sum_row_prefixes = { - let mut this_dfb = hugr.get_optype(node).as_dataflow_block().unwrap().clone(); - let mut nsrp = vec![vec![]; this_dfb.sum_rows.len()]; - vec_prepend(this_dfb.inputs.to_mut(), types.clone()); - - for (this_p, succ_n) in hugr.node_outputs(node).filter_map(|out_p| { - let (succ_n, _) = hugr.single_linked_input(node, out_p).unwrap(); - hugr.get_optype(succ_n) - .is_dataflow_block() - .then_some((out_p.index(), succ_n)) - }) { - let succ_needs_source_indices = self - .needs - .get(succ_n) - .map(|(w, _)| sources.iter().find_position(|(x, _)| x == w).unwrap().0) - .collect_vec(); - let succ_needs_tys = succ_needs_source_indices - .iter() - .copied() - .map(|x| sources[x].1.clone()) - .collect_vec(); - vec_prepend(this_dfb.sum_rows[this_p].to_mut(), succ_needs_tys); - nsrp[this_p] = succ_needs_source_indices; - } - hugr.replace_op(node, this_dfb); - nsrp - }; - - self.thread_dataflow_parent(hugr, node, 0, sources.clone()); - - let [_, o] = hugr.get_io(node).unwrap(); - self.control_worklist.push(ControlWorkItem { - output_node: o, - variant_source_prefixes: new_sum_row_prefixes - .into_iter() - .map(|v| v.into_iter().map(|i| sources[i].0).collect_vec()) - .collect_vec(), - }); - } - - fn do_cfg( - &mut self, - hugr: &mut impl HugrMut, - node: N, - sources: Vec<(Wire, Type)>, - ) { - let types = sources.iter().map(|x| x.1.clone()).collect_vec(); - { - let mut cfg = hugr.get_optype(node).as_cfg().unwrap().clone(); - vec_insert(cfg.signature.input.to_mut(), types, 0); - hugr.replace_op(node, cfg); - } - let new_cond_ports = hugr - .insert_ports(node, Direction::Incoming, 0, sources.len()) - .map_into(); - self.worklist - .extend(mk_workitems(node, sources, new_cond_ports)) - } - - fn do_dfg( - &mut self, - hugr: &mut impl HugrMut, - node: N, - sources: Vec<(Wire, Type)>, - ) { - let mut dfg = hugr.get_optype(node).as_dfg().unwrap().clone(); - let start_new_port_index = dfg.signature.input().len(); - let new_dfg_ports = hugr - .insert_ports( - node, - Direction::Incoming, - start_new_port_index, - sources.len(), - ) - .map_into(); - dfg.signature - .input - .to_mut() - .extend(sources.iter().map(|x| x.1.clone())); - hugr.replace_op(node, dfg); - self.thread_dataflow_parent(hugr, node, start_new_port_index, sources.iter().cloned()); - self.worklist - .extend(mk_workitems(node, sources, new_dfg_ports)); - } - - fn do_conditional( - &mut self, - hugr: &mut impl HugrMut, - node: N, - sources: Vec<(Wire, Type)>, - ) { - let mut cond = hugr.get_optype(node).as_conditional().unwrap().clone(); - let start_new_port_index = cond.signature().input().len(); - cond.other_inputs - .to_mut() - .extend(sources.iter().map(|x| x.1.clone())); - hugr.replace_op(node, cond); - let new_cond_ports = hugr - .insert_ports( - node, - Direction::Incoming, - start_new_port_index, - sources.len(), - ) - .map_into(); - self.worklist - .extend(mk_workitems(node, sources, new_cond_ports)) - } - - fn do_case( - &mut self, - hugr: &mut impl HugrMut, - node: N, - sources: Vec<(Wire, Type)>, - ) { - let mut case = hugr.get_optype(node).as_case().unwrap().clone(); - let start_case_port_index = case.signature.input().len(); - case.signature - .input - .to_mut() - .extend(sources.iter().map(|x| x.1.clone())); - hugr.replace_op(node, case); - self.thread_dataflow_parent(hugr, node, start_case_port_index, sources); - } - - fn do_tailloop( - &mut self, - hugr: &mut impl HugrMut, - node: N, - sources: Vec<(Wire, Type)>, - ) { - let mut tailloop = hugr.get_optype(node).as_tail_loop().unwrap().clone(); - let types = sources.iter().map(|x| x.1.clone()).collect_vec(); - { - vec_prepend(tailloop.just_inputs.to_mut(), types.clone()); - hugr.replace_op(node, tailloop); - } - let tailloop_ports = hugr - .insert_ports(node, Direction::Incoming, 0, sources.len()) - .map_into(); - - self.thread_dataflow_parent(hugr, node, 0, sources.clone()); - - let [_, o] = hugr.get_io(node).unwrap(); - let new_sum_row_prefixes = { - let mut v = vec![vec![]; 2]; - v[TailLoop::CONTINUE_TAG].extend(sources.iter().map(|x| x.0)); - v - }; - self.control_worklist.push(ControlWorkItem { - output_node: o, - variant_source_prefixes: new_sum_row_prefixes, - }); - self.worklist - .extend(mk_workitems(node, sources, tailloop_ports)) - } - - fn finish( - self, - _hugr: &mut impl HugrMut, - ) -> ( - Vec>, - ParentSourceMap, - Vec>, - ) { - (self.worklist, self.parent_source_map, self.control_worklist) - } -} - -fn thread_sources( - hugr: &mut impl HugrMut, - bb_needs_sources_map: &BBNeedsSourcesMap, -) -> ( - Vec>, - ParentSourceMap, - Vec>, -) { - let mut state = ThreadState::new(bb_needs_sources_map); - for (&bb, sources) in bb_needs_sources_map.0.iter() { - let sources = sources.iter().map(|(&w, ty)| (w, ty.clone())).collect_vec(); - match hugr.get_optype(bb).clone() { - OpType::DFG(_) => state.do_dfg(hugr, bb, sources), - OpType::Conditional(_) => state.do_conditional(hugr, bb, sources), - OpType::Case(_) => state.do_case(hugr, bb, sources), - OpType::TailLoop(_) => state.do_tailloop(hugr, bb, sources), - OpType::DataflowBlock(_) => state.do_dataflow_block(hugr, bb, sources), - OpType::CFG(_) => state.do_cfg(hugr, bb, sources), - _ => panic!("impossible"), - } - } - - state.finish(hugr) -} - -fn mk_workitems( - node: N, - sources: impl IntoIterator, Type)>, - ports: impl IntoIterator, -) -> impl Iterator> { - itertools::izip!(sources, ports).map(move |((source, ty), p)| WorkItem { - source, - target: (node, p), - ty, - }) -} - pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), NonLocalEdgesError> { // First we collect all the non-local edges in the graph. We associate them to a WorkItem, which tracks: // * the source of the non-local edge From 1518d4ef943507bb37196b9f62793681af3d3362 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 19 May 2025 19:00:38 +0100 Subject: [PATCH 32/61] Add Types into ControlWorkItem --- hugr-passes/src/non_local.rs | 37 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 921dd104bd..db9119632d 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -300,7 +300,7 @@ impl BBNeedsSourcesMap { let [_, o] = hugr.get_io(node).unwrap(); let new_sum_row_prefixes = { let mut v = vec![vec![]; 2]; - v[TailLoop::CONTINUE_TAG].extend(srcs.iter().map(|(w, _)| w)); + v[TailLoop::CONTINUE_TAG] = srcs.clone(); v }; ControlWorkItem { @@ -323,12 +323,15 @@ impl BBNeedsSourcesMap { .map(|succ| { // The wires required for each successor block, should be available in the predecessor self.get(succ) - .map(|(w, _)| { - if hugr.get_parent(w.node()) == Some(node) { - *w - } else { - *locals.get(w).unwrap() - } + .map(|(w, ty)| { + ( + if hugr.get_parent(w.node()) == Some(node) { + *w + } else { + *locals.get(w).unwrap() + }, + ty.clone(), + ) }) .collect() }) @@ -371,8 +374,8 @@ struct WorkItem { #[derive(Clone, Debug)] struct ControlWorkItem { - output_node: N, // Output node of CFG / TailLoop - variant_source_prefixes: Vec>>, // prefixes to each element of Sum type + output_node: N, // Output node of CFG / TailLoop + variant_source_prefixes: Vec, Type)>>, // prefixes to each element of Sum type } impl ControlWorkItem { @@ -393,15 +396,11 @@ impl ControlWorkItem { panic!("impossible") }; - let mut type_for_source = |source: &Wire| { - let t = hugr.signature(source.node()) - .unwrap() - .out_port_type(source.source()) - .unwrap() - .clone(); - let replaced = needed_sources.insert(*source, (*source, t.clone())); - debug_assert!(!replaced.is_some_and(|x| x != (*source, t.clone()))); - t + let mut type_for_source = |source: &(Wire, Type)| { + let (w, t) = source; + let replaced = needed_sources.insert(*w, (*w, t.clone())); + debug_assert!(!replaced.is_some_and(|x| x != (*w, t.clone()))); + t.clone() }; let old_sum_rows: Vec = sum_type .variants() @@ -434,7 +433,7 @@ impl ControlWorkItem { let case_inputs = case.input_wires().collect_vec(); let mut args = new_sources .into_iter() - .map(|s| { + .map(|(s, _ty)| { case_inputs[old_sum_rows[i].len() + needed_sources .iter() From e019f046e1a51841e0ad4421bc7f935da2358490 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 19 May 2025 19:15:29 +0100 Subject: [PATCH 33/61] Update DFB inputs and sum_rows - tests now passing --- hugr-passes/src/non_local.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index db9119632d..edb239510b 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -312,13 +312,20 @@ impl BBNeedsSourcesMap { } fn thread_bb(&self, hugr: &mut impl HugrMut, node: N) { + let OpType::DataflowBlock(this_dfb) = hugr.optype_mut(node) else { + panic!("Expected dataflow block") + }; + vec_prepend( + this_dfb.inputs.to_mut(), + self.get(node).map(|(_, t)| t.clone()), + ); let locals = self.thread_dataflow_parent( hugr, node, 0, self.get(node).map(|(w, t)| (*w, t.clone())).collect(), ); - let variant_source_prefixes = hugr + let variant_source_prefixes: Vec, Type)>> = hugr .output_neighbours(node) .map(|succ| { // The wires required for each successor block, should be available in the predecessor @@ -336,6 +343,18 @@ impl BBNeedsSourcesMap { .collect() }) .collect(); + let OpType::DataflowBlock(this_dfb) = hugr.optype_mut(node) else { + panic!("It worked earlier!") + }; + for (source_prefix, sum_row) in variant_source_prefixes + .iter() + .zip_eq(this_dfb.sum_rows.iter_mut()) + { + vec_prepend( + sum_row.to_mut(), + source_prefix.iter().map(|(_, t)| t.clone()), + ); + } let [_, output_node] = hugr.get_io(node).unwrap(); ControlWorkItem { output_node, From a41ae95a90c729a05c3262f963a5c370152876b2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 19 May 2025 19:22:31 +0100 Subject: [PATCH 34/61] Add just_types helper --- hugr-passes/src/non_local.rs | 45 +++++++++++------------------------- 1 file changed, 14 insertions(+), 31 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index edb239510b..445f218e44 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -178,33 +178,25 @@ impl BBNeedsSourcesMap { OpType::DFG(dfg) => { let ins = dfg.signature.input.to_mut(); let start_new_port_index = ins.len(); - ins.extend(sources.iter().map(|(_, t)| t.clone())); + ins.extend(just_types(&sources)); self.thread_dataflow_parent(hugr, node, start_new_port_index, sources); start_new_port_index } OpType::Conditional(cond) => { let start_new_port_index = cond.signature().input.len(); - cond.other_inputs - .to_mut() - .extend(sources.iter().map(|x| x.1.clone())); + cond.other_inputs.to_mut().extend(just_types(&sources)); self.thread_conditional(hugr, node, sources); start_new_port_index } OpType::TailLoop(tail_op) => { - vec_prepend( - tail_op.just_inputs.to_mut(), - sources.iter().map(|(_, t)| t.clone()), - ); + vec_prepend(tail_op.just_inputs.to_mut(), just_types(&sources)); self.thread_tailloop(hugr, node, sources); 0 } OpType::CFG(cfg) => { - vec_prepend( - cfg.signature.input.to_mut(), - sources.iter().map(|(_, t)| t.clone()), - ); + vec_prepend(cfg.signature.input.to_mut(), just_types(&sources)); assert_eq!( self.get(node).collect::>(), self.get(hugr.children(node).next().unwrap()) @@ -286,7 +278,7 @@ impl BBNeedsSourcesMap { }; let ins = case_op.signature.input.to_mut(); let start_case_port_index = ins.len(); - ins.extend(srcs.iter().map(|(_, t)| t.clone())); + ins.extend(just_types(&srcs)); self.thread_dataflow_parent(hugr, case, start_case_port_index, srcs.clone()); } } @@ -315,16 +307,9 @@ impl BBNeedsSourcesMap { let OpType::DataflowBlock(this_dfb) = hugr.optype_mut(node) else { panic!("Expected dataflow block") }; - vec_prepend( - this_dfb.inputs.to_mut(), - self.get(node).map(|(_, t)| t.clone()), - ); - let locals = self.thread_dataflow_parent( - hugr, - node, - 0, - self.get(node).map(|(w, t)| (*w, t.clone())).collect(), - ); + let my_inputs: Vec<_> = self.get(node).map(|(w, t)| (*w, t.clone())).collect(); + vec_prepend(this_dfb.inputs.to_mut(), just_types(&my_inputs)); + let locals = self.thread_dataflow_parent(hugr, node, 0, my_inputs); let variant_source_prefixes: Vec, Type)>> = hugr .output_neighbours(node) .map(|succ| { @@ -350,10 +335,7 @@ impl BBNeedsSourcesMap { .iter() .zip_eq(this_dfb.sum_rows.iter_mut()) { - vec_prepend( - sum_row.to_mut(), - source_prefix.iter().map(|(_, t)| t.clone()), - ); + vec_prepend(sum_row.to_mut(), just_types(source_prefix)); } let [_, output_node] = hugr.get_io(node).unwrap(); ControlWorkItem { @@ -364,6 +346,10 @@ impl BBNeedsSourcesMap { } } +fn just_types<'a, X: 'a>(v: impl IntoIterator) -> impl Iterator { + v.into_iter().map(|(_, t)| t.clone()) +} + #[derive(derive_more::Error, derive_more::From, derive_more::Display, Debug, PartialEq)] #[non_exhaustive] pub enum NonLocalEdgesError { @@ -440,10 +426,7 @@ impl ControlWorkItem { let new_control_type = Type::new_sum(new_sum_rows.clone()); let mut cond = ConditionalBuilder::new( old_sum_rows.clone(), - needed_sources - .values() - .map(|(_, t)| t.clone()) - .collect_vec(), + just_types(needed_sources.values()).collect_vec(), new_control_type.clone(), ) .unwrap(); From 68f30204c5446c1a394992942af52256d0dbda62 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 09:04:21 +0100 Subject: [PATCH 35/61] Renaming, add struct + impl ComposablePass --- hugr-passes/src/lib.rs | 4 +- hugr-passes/src/{non_local.rs => localize.rs} | 60 ++++++++----------- 2 files changed, 26 insertions(+), 38 deletions(-) rename hugr-passes/src/{non_local.rs => localize.rs} (96%) diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index c82fc5abe6..c23da37a42 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -22,9 +22,9 @@ pub mod untuple; pub use monomorphize::{MonomorphizePass, mangle_name, monomorphize}; pub mod replace_types; pub use replace_types::ReplaceTypes; +pub mod localize; pub mod nest_cfgs; -pub mod non_local; pub use force_order::{force_order, force_order_by_key}; +pub use localize::{ensure_no_nonlocal_edges, nonlocal_edges}; pub use lower::{lower_ops, replace_many_ops}; -pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; pub use untuple::UntuplePass; diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/localize.rs similarity index 96% rename from hugr-passes/src/non_local.rs rename to hugr-passes/src/localize.rs index 445f218e44..ec16152b38 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/localize.rs @@ -1,5 +1,5 @@ -//! This module provides functions for inspecting and modifying the nature of -//! non local edges in a Hugr. +//! This module provides functions for finding non-local edges +//! in a Hugr and converting them to local edges. use delegate::delegate; use std::{ collections::{BTreeMap, HashMap}, @@ -10,41 +10,29 @@ use hugr_core::{HugrView, IncomingPort, core::HugrNode}; use itertools::{Either, Itertools as _}; use hugr_core::{ - Direction, PortIndex, Wire, + Direction, Wire, builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, hugr::{HugrError, hugrmut::HugrMut}, ops::{DataflowOpTrait as _, OpType, Tag, TailLoop}, types::{EdgeKind, Type, TypeRow}, }; -// use crate::validation::{ValidatePassError, ValidationLevel}; - -/// TODO docs -// #[derive(Debug, Clone, Default)] -// pub struct UnNonLocalPass { -// validation: ValidationLevel, -// } - -// impl UnNonLocalPass { -// /// Sets the validation level used before and after the pass is run. -// pub fn validation_level(mut self, level: ValidationLevel) -> Self { -// self.validation = level; -// self -// } - -// /// Run the Monomorphization pass. -// fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), NonLocalEdgesError> { -// let root = hugr.root(); -// remove_nonlocal_edges(hugr, root)?; -// Ok(()) -// } - -// /// Run the pass using specified configuration. -// pub fn run(&self, hugr: &mut H) -> Result<(), NonLocalEdgesError> { -// self.validation -// .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) -// } -// } +use crate::ComposablePass; + +/// [ComposablePass] that converts all non-local edges in a Hugr +/// into local ones, by inserting extra inputs to container nodes +/// and extra outports to Input nodes. +struct LocalizeEdges; + +impl ComposablePass for LocalizeEdges { + type Error = NonLocalEdgesError; + + type Result = (); + + fn run(&self, hugr: &mut H) -> Result { + remove_nonlocal_edges(hugr) + } +} /// Returns an iterator over all non local edges in a Hugr. /// @@ -62,7 +50,7 @@ pub fn nonlocal_edges(hugr: &H) -> impl Iterator(BTreeMap, Type>>); impl Default for BBNeedsSourcesMap { @@ -610,7 +598,7 @@ mod test { } #[test] - fn unnonlocal_dfg() { + fn localize_dfg() { let mut hugr = { let mut outer = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let [w0] = outer.input_wires_arr(); @@ -629,7 +617,7 @@ mod test { } #[test] - fn unnonlocal_tailloop() { + fn localize_tailloop() { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); let mut hugr = { let mut outer = DFGBuilder::new(Signature::new_endo(vec![ @@ -672,7 +660,7 @@ mod test { } #[test] - fn unnonlocal_conditional() { + fn localize_conditional() { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); let out_variants = vec![t1.clone().into(), t2.clone().into()]; let out_type = Type::new_sum(out_variants.clone()); @@ -724,7 +712,7 @@ mod test { } #[test] - fn unnonlocal_cfg() { + fn localize_cfg() { // Cfg consists of 4 dataflow blocks and an exit block // // The 4 dataflow blocks form a diamond, and the bottom block branches From 4cc2948355f7e8cf8b920bff3d4ce425b2765b20 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 09:41:00 +0100 Subject: [PATCH 36/61] Undo move --- hugr-passes/src/lib.rs | 4 ++-- hugr-passes/src/{localize.rs => non_local.rs} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename hugr-passes/src/{localize.rs => non_local.rs} (100%) diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index c23da37a42..c82fc5abe6 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -22,9 +22,9 @@ pub mod untuple; pub use monomorphize::{MonomorphizePass, mangle_name, monomorphize}; pub mod replace_types; pub use replace_types::ReplaceTypes; -pub mod localize; pub mod nest_cfgs; +pub mod non_local; pub use force_order::{force_order, force_order_by_key}; -pub use localize::{ensure_no_nonlocal_edges, nonlocal_edges}; pub use lower::{lower_ops, replace_many_ops}; +pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; pub use untuple::UntuplePass; diff --git a/hugr-passes/src/localize.rs b/hugr-passes/src/non_local.rs similarity index 100% rename from hugr-passes/src/localize.rs rename to hugr-passes/src/non_local.rs From c4000e8cb29c9f6a58eb9eb063a05728c03102d7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 10:00:58 +0100 Subject: [PATCH 37/61] separate errors, deprecate NonLocalEdgesError --- hugr-passes/src/non_local.rs | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index ec16152b38..3ae96d613d 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -25,7 +25,7 @@ use crate::ComposablePass; struct LocalizeEdges; impl ComposablePass for LocalizeEdges { - type Error = NonLocalEdgesError; + type Error = LocalizeEdgesError; type Result = (); @@ -338,26 +338,36 @@ fn just_types<'a, X: 'a>(v: impl IntoIterator) -> impl Ite v.into_iter().map(|(_, t)| t.clone()) } -#[derive(derive_more::Error, derive_more::From, derive_more::Display, Debug, PartialEq)] +#[deprecated(note = "Use FindNonLocalEdgesError")] +pub type NonLocalEdgesError = FindNonLocalEdgesError; + +/// An error from [ensure_no_nonlocal_edges] +#[derive(Clone, derive_more::Error, derive_more::Display, Debug, PartialEq, Eq)] #[non_exhaustive] -pub enum NonLocalEdgesError { +pub enum FindNonLocalEdgesError { #[display("Found {} nonlocal edges", _0.len())] #[error(ignore)] Edges(Vec<(N, IncomingPort)>), - #[from] - HugrError(HugrError), } /// Verifies that there are no non local value edges in the Hugr. -pub fn ensure_no_nonlocal_edges(hugr: &H) -> Result<(), NonLocalEdgesError> { +pub fn ensure_no_nonlocal_edges( + hugr: &H, +) -> Result<(), FindNonLocalEdgesError> { let non_local_edges: Vec<_> = nonlocal_edges(hugr).collect_vec(); if non_local_edges.is_empty() { Ok(()) } else { - Err(NonLocalEdgesError::Edges(non_local_edges))? + Err(FindNonLocalEdgesError::Edges(non_local_edges))? } } +#[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug, PartialEq)] +#[non_exhaustive] +pub enum LocalizeEdgesError { + HugrError(#[from] HugrError), +} + #[derive(Debug, Clone)] struct WorkItem { source: Wire, @@ -456,7 +466,7 @@ impl ControlWorkItem { } } -pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), NonLocalEdgesError> { +pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdgesError> { // First we collect all the non-local edges in the graph. We associate them to a WorkItem, which tracks: // * the source of the non-local edge // * the target of the non-local edge @@ -593,7 +603,7 @@ mod test { }; assert_eq!( ensure_no_nonlocal_edges(&hugr).unwrap_err(), - NonLocalEdgesError::Edges(vec![edge]) + FindNonLocalEdgesError::Edges(vec![edge]) ); } From 8608a5fa4e00542834f6e7407e5c7f1dbf1576af Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 10:09:42 +0100 Subject: [PATCH 38/61] Convert ControlWorkItem into a simple function --- hugr-passes/src/non_local.rs | 174 +++++++++++++++++------------------ 1 file changed, 82 insertions(+), 92 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 3ae96d613d..083d28238a 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -283,11 +283,7 @@ impl BBNeedsSourcesMap { v[TailLoop::CONTINUE_TAG] = srcs.clone(); v }; - ControlWorkItem { - output_node: o, - variant_source_prefixes: new_sum_row_prefixes, - } - .go(hugr); + add_control_prefixes(hugr, o, new_sum_row_prefixes); self.thread_dataflow_parent(hugr, node, 0, srcs); } @@ -326,11 +322,7 @@ impl BBNeedsSourcesMap { vec_prepend(sum_row.to_mut(), just_types(source_prefix)); } let [_, output_node] = hugr.get_io(node).unwrap(); - ControlWorkItem { - output_node, - variant_source_prefixes, - } - .go(hugr) + add_control_prefixes(hugr, output_node, variant_source_prefixes); } } @@ -375,95 +367,93 @@ struct WorkItem { ty: Type, } -#[derive(Clone, Debug)] -struct ControlWorkItem { - output_node: N, // Output node of CFG / TailLoop - variant_source_prefixes: Vec, Type)>>, // prefixes to each element of Sum type -} - -impl ControlWorkItem { - fn go(self, hugr: &mut impl HugrMut) { - let parent = hugr.get_parent(self.output_node).unwrap(); - let Some(mut output) = hugr.get_optype(self.output_node).as_output().cloned() else { +/// `variant_source_prefixes` are extra wires/types to prepend onto each variant +/// (must have one element per variant of control Sum) +fn add_control_prefixes( + hugr: &mut H, + output_node: H::Node, + variant_source_prefixes: Vec, Type)>>, +) { + let parent = hugr.get_parent(output_node).unwrap(); + let Some(mut output) = hugr.get_optype(output_node).as_output().cloned() else { + panic!("impossible") + }; + let mut needed_sources = BTreeMap::new(); + let (cond, new_control_type) = { + let Some(EdgeKind::Value(control_type)) = hugr + .get_optype(output_node) + .port_kind(IncomingPort::from(0)) + else { + panic!("impossible") + }; + let Some(sum_type) = control_type.as_sum() else { panic!("impossible") }; - let mut needed_sources = BTreeMap::new(); - let (cond, new_control_type) = { - let Some(EdgeKind::Value(control_type)) = hugr - .get_optype(self.output_node) - .port_kind(IncomingPort::from(0)) - else { - panic!("impossible") - }; - let Some(sum_type) = control_type.as_sum() else { - panic!("impossible") - }; - let mut type_for_source = |source: &(Wire, Type)| { - let (w, t) = source; - let replaced = needed_sources.insert(*w, (*w, t.clone())); - debug_assert!(!replaced.is_some_and(|x| x != (*w, t.clone()))); - t.clone() - }; - let old_sum_rows: Vec = sum_type - .variants() - .map(|x| x.clone().try_into().unwrap()) + let mut type_for_source = |source: &(Wire, Type)| { + let (w, t) = source; + let replaced = needed_sources.insert(*w, (*w, t.clone())); + debug_assert!(!replaced.is_some_and(|x| x != (*w, t.clone()))); + t.clone() + }; + let old_sum_rows: Vec = sum_type + .variants() + .map(|x| x.clone().try_into().unwrap()) + .collect_vec(); + let new_sum_rows: Vec = + itertools::zip_eq(variant_source_prefixes.iter(), old_sum_rows.iter()) + .map(|(new_sources, old_tys)| { + new_sources + .iter() + .map(&mut type_for_source) + .chain(old_tys.iter().cloned()) + .collect_vec() + .into() + }) .collect_vec(); - let new_sum_rows: Vec = - itertools::zip_eq(self.variant_source_prefixes.iter(), old_sum_rows.iter()) - .map(|(new_sources, old_tys)| { - new_sources + + let new_control_type = Type::new_sum(new_sum_rows.clone()); + let mut cond = ConditionalBuilder::new( + old_sum_rows.clone(), + just_types(needed_sources.values()).collect_vec(), + new_control_type.clone(), + ) + .unwrap(); + for (i, new_sources) in variant_source_prefixes.into_iter().enumerate() { + let mut case = cond.case_builder(i).unwrap(); + let case_inputs = case.input_wires().collect_vec(); + let mut args = new_sources + .into_iter() + .map(|(s, _ty)| { + case_inputs[old_sum_rows[i].len() + + needed_sources .iter() - .map(&mut type_for_source) - .chain(old_tys.iter().cloned()) - .collect_vec() - .into() - }) - .collect_vec(); - - let new_control_type = Type::new_sum(new_sum_rows.clone()); - let mut cond = ConditionalBuilder::new( - old_sum_rows.clone(), - just_types(needed_sources.values()).collect_vec(), - new_control_type.clone(), - ) - .unwrap(); - for (i, new_sources) in self.variant_source_prefixes.into_iter().enumerate() { - let mut case = cond.case_builder(i).unwrap(); - let case_inputs = case.input_wires().collect_vec(); - let mut args = new_sources - .into_iter() - .map(|(s, _ty)| { - case_inputs[old_sum_rows[i].len() - + needed_sources - .iter() - .find_position(|(w, _)| **w == s) - .unwrap() - .0] - }) - .collect_vec(); - args.extend(&case_inputs[..old_sum_rows[i].len()]); - let case_outputs = case - .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), args) - .unwrap() - .outputs(); - case.finish_with_outputs(case_outputs).unwrap(); - } - (cond.finish_hugr().unwrap(), new_control_type) - }; - let cond_node = hugr.insert_hugr(parent, cond).inserted_entrypoint; - let (old_output_source_node, old_output_source_port) = - hugr.single_linked_output(self.output_node, 0).unwrap(); - debug_assert_eq!(hugr.get_parent(old_output_source_node).unwrap(), parent); - hugr.connect(old_output_source_node, old_output_source_port, cond_node, 0); - for (i, &(w, _)) in needed_sources.values().enumerate() { - hugr.connect(w.node(), w.source(), cond_node, i + 1); + .find_position(|(w, _)| **w == s) + .unwrap() + .0] + }) + .collect_vec(); + args.extend(&case_inputs[..old_sum_rows[i].len()]); + let case_outputs = case + .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), args) + .unwrap() + .outputs(); + case.finish_with_outputs(case_outputs).unwrap(); } - hugr.disconnect(self.output_node, IncomingPort::from(0)); - hugr.connect(cond_node, 0, self.output_node, 0); - output.types.to_mut()[0] = new_control_type; - hugr.replace_op(self.output_node, output); + (cond.finish_hugr().unwrap(), new_control_type) + }; + let cond_node = hugr.insert_hugr(parent, cond).inserted_entrypoint; + let (old_output_source_node, old_output_source_port) = + hugr.single_linked_output(output_node, 0).unwrap(); + debug_assert_eq!(hugr.get_parent(old_output_source_node).unwrap(), parent); + hugr.connect(old_output_source_node, old_output_source_port, cond_node, 0); + for (i, &(w, _)) in needed_sources.values().enumerate() { + hugr.connect(w.node(), w.source(), cond_node, i + 1); } + hugr.disconnect(output_node, IncomingPort::from(0)); + hugr.connect(cond_node, 0, output_node, 0); + output.types.to_mut()[0] = new_control_type; + hugr.replace_op(output_node, output); } pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdgesError> { From aee31a1016c4564f307494d6f95c8392ca6a45e1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 10:13:04 +0100 Subject: [PATCH 39/61] Use optype_mut --- hugr-passes/src/non_local.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 083d28238a..46149dbbb9 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -374,10 +374,8 @@ fn add_control_prefixes( output_node: H::Node, variant_source_prefixes: Vec, Type)>>, ) { + debug_assert!(hugr.get_optype(output_node).is_output()); // Just to fail fast let parent = hugr.get_parent(output_node).unwrap(); - let Some(mut output) = hugr.get_optype(output_node).as_output().cloned() else { - panic!("impossible") - }; let mut needed_sources = BTreeMap::new(); let (cond, new_control_type) = { let Some(EdgeKind::Value(control_type)) = hugr @@ -452,8 +450,10 @@ fn add_control_prefixes( } hugr.disconnect(output_node, IncomingPort::from(0)); hugr.connect(cond_node, 0, output_node, 0); + let OpType::Output(output) = hugr.optype_mut(output_node) else { + panic!("impossible") + }; output.types.to_mut()[0] = new_control_type; - hugr.replace_op(output_node, output); } pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdgesError> { From 90e1029acdb6d1a52a02453941f0a6f67a937b86 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 13:10:28 +0100 Subject: [PATCH 40/61] Move BBNeedsSourcesMap(Builder) into submodule non_local/localize.rs --- hugr-passes/src/non_local.rs | 415 ++------------------------ hugr-passes/src/non_local/localize.rs | 404 +++++++++++++++++++++++++ 2 files changed, 421 insertions(+), 398 deletions(-) create mode 100644 hugr-passes/src/non_local/localize.rs diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 46149dbbb9..bf86019352 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,20 +1,17 @@ //! This module provides functions for finding non-local edges //! in a Hugr and converting them to local edges. -use delegate::delegate; -use std::{ - collections::{BTreeMap, HashMap}, - iter, mem, -}; +use std::collections::HashMap; use hugr_core::{HugrView, IncomingPort, core::HugrNode}; -use itertools::{Either, Itertools as _}; +use itertools::Itertools as _; + +mod localize; +use localize::{BBNeedsSourcesMap, BBNeedsSourcesMapBuilder}; use hugr_core::{ - Direction, Wire, - builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, + Wire, hugr::{HugrError, hugrmut::HugrMut}, - ops::{DataflowOpTrait as _, OpType, Tag, TailLoop}, - types::{EdgeKind, Type, TypeRow}, + types::{EdgeKind, Type}, }; use crate::ComposablePass; @@ -22,7 +19,13 @@ use crate::ComposablePass; /// [ComposablePass] that converts all non-local edges in a Hugr /// into local ones, by inserting extra inputs to container nodes /// and extra outports to Input nodes. -struct LocalizeEdges; +pub struct LocalizeEdges; + +#[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug, PartialEq)] +#[non_exhaustive] +pub enum LocalizeEdgesError { + HugrError(#[from] HugrError), +} impl ComposablePass for LocalizeEdges { type Error = LocalizeEdgesError; @@ -47,87 +50,6 @@ pub fn nonlocal_edges(hugr: &H) -> impl Iterator(BTreeMap, Type>>); - -impl Default for BBNeedsSourcesMap { - fn default() -> Self { - Self(BTreeMap::default()) - } -} - -impl BBNeedsSourcesMap { - fn insert(&mut self, node: N, source: Wire, ty: Type) -> bool { - self.0.entry(node).or_default().insert(source, ty).is_none() - } - - fn get(&self, node: N) -> impl Iterator, &Type)> + '_ { - match self.0.get(&node) { - Some(x) => Either::Left(x.iter()), - None => Either::Right(iter::empty()), - } - } - - delegate! { - to self.0 { - fn keys(&self) -> impl Iterator; - } - } -} - -#[derive(Debug, Clone)] -struct BBNeedsSourcesMapBuilder { - hugr: H, - needs_sources: BBNeedsSourcesMap, -} - -impl BBNeedsSourcesMapBuilder { - fn new(hugr: H) -> Self { - Self { - hugr, - needs_sources: Default::default(), - } - } - - fn insert(&mut self, mut parent: H::Node, source: Wire, ty: Type) { - let source_parent = self.hugr.get_parent(source.node()).unwrap(); - while source_parent != parent { - if !self.needs_sources.insert(parent, source, ty.clone()) { - break; - } - if self.hugr.get_optype(parent).is_conditional() { - // One of these we must have just done on the previous iteration - for case in self.hugr.children(parent) { - // Full recursion unnecessary as we've just added parent: - self.needs_sources.insert(case, source, ty.clone()); - } - } - // this will panic if source_parent is not an ancestor of target - let parent_parent = self.hugr.get_parent(parent).unwrap(); - if self.hugr.get_optype(parent).is_dataflow_block() { - assert!(self.hugr.get_optype(parent_parent).is_cfg()); - for pred in self.hugr.input_neighbours(parent).collect::>() { - self.insert(pred, source, ty.clone()); - } - if Some(parent) != self.hugr.children(parent_parent).next() { - // Recursive calls on predecessors will have traced back to entry block - // (or source_parent itself if a dominating Basic Block) - break; - } - // We've just added to entry node - so must add to CFG as well - } - parent = parent_parent; - } - } - - fn finish(self) -> BBNeedsSourcesMap { - self.needs_sources - } -} - // Identify all required extra inputs (for both Dom and Ext edges) fn build_needs_sources_map( hugr: impl HugrView, @@ -142,194 +64,6 @@ fn build_needs_sources_map( bnsm.finish() } -// Transformation: adding extra ports, and wiring them up =============================== -impl BBNeedsSourcesMap { - fn thread_node( - &self, - hugr: &mut impl HugrMut, - node: N, - locals: &HashMap, Wire>, - ) { - if self.get(node).next().is_none() { - // No edges incoming into this subtree, but there could still be nonlocal edges internal to it - for ch in hugr.children(node).collect::>() { - self.thread_node(hugr, ch, &HashMap::new()) - } - return; - } - - let sources: Vec<(Wire, Type)> = self.get(node).map(|(w, t)| (*w, t.clone())).collect(); - let src_wires: Vec> = sources.iter().map(|(w, _)| *w).collect(); - - // `match` must deal with everything inside the node, and update the signature (per OpType) - let start_new_port_index = match hugr.optype_mut(node) { - OpType::DFG(dfg) => { - let ins = dfg.signature.input.to_mut(); - let start_new_port_index = ins.len(); - ins.extend(just_types(&sources)); - - self.thread_dataflow_parent(hugr, node, start_new_port_index, sources); - start_new_port_index - } - OpType::Conditional(cond) => { - let start_new_port_index = cond.signature().input.len(); - cond.other_inputs.to_mut().extend(just_types(&sources)); - - self.thread_conditional(hugr, node, sources); - start_new_port_index - } - OpType::TailLoop(tail_op) => { - vec_prepend(tail_op.just_inputs.to_mut(), just_types(&sources)); - self.thread_tailloop(hugr, node, sources); - 0 - } - OpType::CFG(cfg) => { - vec_prepend(cfg.signature.input.to_mut(), just_types(&sources)); - assert_eq!( - self.get(node).collect::>(), - self.get(hugr.children(node).next().unwrap()) - .collect::>() - ); // Entry node - for bb in hugr.children(node).collect::>() { - if hugr.get_optype(bb).is_dataflow_block() { - self.thread_bb(hugr, bb); - } - } - 0 - } - _ => panic!( - "All containers handled except Module/FuncDefn or root Case/DFB, which should not have incoming nonlocal edges" - ), - }; - - let new_dfg_ports = hugr.insert_ports( - node, - Direction::Incoming, - start_new_port_index, - src_wires.len(), - ); - let local_srcs = src_wires.into_iter().map(|w| *locals.get(&w).unwrap_or(&w)); - for (w, tgt_port) in local_srcs.zip_eq(new_dfg_ports) { - assert_eq!(hugr.get_parent(w.node()), hugr.get_parent(node)); - hugr.connect(w.node(), w.source(), node, tgt_port) - } - } - - fn thread_dataflow_parent( - &self, - hugr: &mut impl HugrMut, - node: N, - start_new_port_index: usize, - srcs: Vec<(Wire, Type)>, - ) -> HashMap, Wire> { - let nlocals = if srcs.is_empty() { - HashMap::new() - } else { - let (srcs, tys): (Vec<_>, Vec) = srcs.into_iter().unzip(); - let [inp, _] = hugr.get_io(node).unwrap(); - let OpType::Input(in_op) = hugr.optype_mut(inp) else { - panic!("Expected Input node") - }; - vec_insert(in_op.types.to_mut(), tys, start_new_port_index); - let new_outports = - hugr.insert_ports(inp, Direction::Outgoing, start_new_port_index, srcs.len()); - - srcs.into_iter() - .zip_eq(new_outports) - .map(|(w, p)| (w, Wire::new(inp, p))) - .collect() - }; - for ch in hugr.children(node).collect::>() { - for (inp, _) in hugr.in_value_types(ch).collect::>() { - if let Some((src_n, src_p)) = hugr.single_linked_output(ch, inp) { - if hugr.get_parent(src_n) != Some(node) { - hugr.disconnect(ch, inp); - let new_p = nlocals.get(&Wire::new(src_n, src_p)).unwrap(); - hugr.connect(new_p.node(), new_p.source(), ch, inp); - } - } - } - self.thread_node(hugr, ch, &nlocals); - } - nlocals - } - - fn thread_conditional( - &self, - hugr: &mut impl HugrMut, - node: N, - srcs: Vec<(Wire, Type)>, - ) { - for case in hugr.children(node).collect::>() { - let OpType::Case(case_op) = hugr.optype_mut(case) else { - continue; - }; - let ins = case_op.signature.input.to_mut(); - let start_case_port_index = ins.len(); - ins.extend(just_types(&srcs)); - self.thread_dataflow_parent(hugr, case, start_case_port_index, srcs.clone()); - } - } - - fn thread_tailloop( - &self, - hugr: &mut impl HugrMut, - node: N, - srcs: Vec<(Wire, Type)>, - ) { - let [_, o] = hugr.get_io(node).unwrap(); - let new_sum_row_prefixes = { - let mut v = vec![vec![]; 2]; - v[TailLoop::CONTINUE_TAG] = srcs.clone(); - v - }; - add_control_prefixes(hugr, o, new_sum_row_prefixes); - self.thread_dataflow_parent(hugr, node, 0, srcs); - } - - fn thread_bb(&self, hugr: &mut impl HugrMut, node: N) { - let OpType::DataflowBlock(this_dfb) = hugr.optype_mut(node) else { - panic!("Expected dataflow block") - }; - let my_inputs: Vec<_> = self.get(node).map(|(w, t)| (*w, t.clone())).collect(); - vec_prepend(this_dfb.inputs.to_mut(), just_types(&my_inputs)); - let locals = self.thread_dataflow_parent(hugr, node, 0, my_inputs); - let variant_source_prefixes: Vec, Type)>> = hugr - .output_neighbours(node) - .map(|succ| { - // The wires required for each successor block, should be available in the predecessor - self.get(succ) - .map(|(w, ty)| { - ( - if hugr.get_parent(w.node()) == Some(node) { - *w - } else { - *locals.get(w).unwrap() - }, - ty.clone(), - ) - }) - .collect() - }) - .collect(); - let OpType::DataflowBlock(this_dfb) = hugr.optype_mut(node) else { - panic!("It worked earlier!") - }; - for (source_prefix, sum_row) in variant_source_prefixes - .iter() - .zip_eq(this_dfb.sum_rows.iter_mut()) - { - vec_prepend(sum_row.to_mut(), just_types(source_prefix)); - } - let [_, output_node] = hugr.get_io(node).unwrap(); - add_control_prefixes(hugr, output_node, variant_source_prefixes); - } -} - -fn just_types<'a, X: 'a>(v: impl IntoIterator) -> impl Iterator { - v.into_iter().map(|(_, t)| t.clone()) -} - #[deprecated(note = "Use FindNonLocalEdgesError")] pub type NonLocalEdgesError = FindNonLocalEdgesError; @@ -354,12 +88,6 @@ pub fn ensure_no_nonlocal_edges( } } -#[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug, PartialEq)] -#[non_exhaustive] -pub enum LocalizeEdgesError { - HugrError(#[from] HugrError), -} - #[derive(Debug, Clone)] struct WorkItem { source: Wire, @@ -367,93 +95,8 @@ struct WorkItem { ty: Type, } -/// `variant_source_prefixes` are extra wires/types to prepend onto each variant -/// (must have one element per variant of control Sum) -fn add_control_prefixes( - hugr: &mut H, - output_node: H::Node, - variant_source_prefixes: Vec, Type)>>, -) { - debug_assert!(hugr.get_optype(output_node).is_output()); // Just to fail fast - let parent = hugr.get_parent(output_node).unwrap(); - let mut needed_sources = BTreeMap::new(); - let (cond, new_control_type) = { - let Some(EdgeKind::Value(control_type)) = hugr - .get_optype(output_node) - .port_kind(IncomingPort::from(0)) - else { - panic!("impossible") - }; - let Some(sum_type) = control_type.as_sum() else { - panic!("impossible") - }; - - let mut type_for_source = |source: &(Wire, Type)| { - let (w, t) = source; - let replaced = needed_sources.insert(*w, (*w, t.clone())); - debug_assert!(!replaced.is_some_and(|x| x != (*w, t.clone()))); - t.clone() - }; - let old_sum_rows: Vec = sum_type - .variants() - .map(|x| x.clone().try_into().unwrap()) - .collect_vec(); - let new_sum_rows: Vec = - itertools::zip_eq(variant_source_prefixes.iter(), old_sum_rows.iter()) - .map(|(new_sources, old_tys)| { - new_sources - .iter() - .map(&mut type_for_source) - .chain(old_tys.iter().cloned()) - .collect_vec() - .into() - }) - .collect_vec(); - - let new_control_type = Type::new_sum(new_sum_rows.clone()); - let mut cond = ConditionalBuilder::new( - old_sum_rows.clone(), - just_types(needed_sources.values()).collect_vec(), - new_control_type.clone(), - ) - .unwrap(); - for (i, new_sources) in variant_source_prefixes.into_iter().enumerate() { - let mut case = cond.case_builder(i).unwrap(); - let case_inputs = case.input_wires().collect_vec(); - let mut args = new_sources - .into_iter() - .map(|(s, _ty)| { - case_inputs[old_sum_rows[i].len() - + needed_sources - .iter() - .find_position(|(w, _)| **w == s) - .unwrap() - .0] - }) - .collect_vec(); - args.extend(&case_inputs[..old_sum_rows[i].len()]); - let case_outputs = case - .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), args) - .unwrap() - .outputs(); - case.finish_with_outputs(case_outputs).unwrap(); - } - (cond.finish_hugr().unwrap(), new_control_type) - }; - let cond_node = hugr.insert_hugr(parent, cond).inserted_entrypoint; - let (old_output_source_node, old_output_source_port) = - hugr.single_linked_output(output_node, 0).unwrap(); - debug_assert_eq!(hugr.get_parent(old_output_source_node).unwrap(), parent); - hugr.connect(old_output_source_node, old_output_source_port, cond_node, 0); - for (i, &(w, _)) in needed_sources.values().enumerate() { - hugr.connect(w.node(), w.source(), cond_node, i + 1); - } - hugr.disconnect(output_node, IncomingPort::from(0)); - hugr.connect(cond_node, 0, output_node, 0); - let OpType::Output(output) = hugr.optype_mut(output_node) else { - panic!("impossible") - }; - output.types.to_mut()[0] = new_control_type; +fn just_types<'a, X: 'a>(v: impl IntoIterator) -> impl Iterator { + v.into_iter().map(|(_, t)| t.clone()) } pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdgesError> { @@ -513,21 +156,11 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg } } - bb_needs_sources_map.thread_node(hugr, hugr.entrypoint(), &HashMap::new()); + bb_needs_sources_map.thread_hugr(hugr); Ok(()) } -fn vec_prepend(v: &mut Vec, ts: impl IntoIterator) { - vec_insert(v, ts, 0) -} - -fn vec_insert(v: &mut Vec, ts: impl IntoIterator, index: usize) { - let mut old_v_iter = mem::take(v).into_iter(); - v.extend(old_v_iter.by_ref().take(index).chain(ts)); - v.extend(old_v_iter); -} - #[cfg(test)] mod test { use hugr_core::{ @@ -540,20 +173,6 @@ mod test { use super::*; - #[test] - fn vec_insert0() { - let mut v = vec![5, 7, 9]; - vec_insert(&mut v, [1, 2], 0); - assert_eq!(v, [1, 2, 5, 7, 9]); - } - - #[test] - fn vec_insert1() { - let mut v = vec![5, 7, 9]; - vec_insert(&mut v, [1, 2], 1); - assert_eq!(v, [5, 1, 2, 7, 9]); - } - #[test] fn ensures_no_nonlocal_edges() { let hugr = { diff --git a/hugr-passes/src/non_local/localize.rs b/hugr-passes/src/non_local/localize.rs new file mode 100644 index 0000000000..9b4d623063 --- /dev/null +++ b/hugr-passes/src/non_local/localize.rs @@ -0,0 +1,404 @@ +//! Implementation of [super::LocalizeEdgesPass] + +use std::collections::{BTreeMap, HashMap}; + +use delegate::delegate; + +use hugr_core::{ + Direction, HugrView, IncomingPort, Wire, + builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, + core::HugrNode, + hugr::hugrmut::HugrMut, + ops::{DataflowOpTrait, OpType, Tag, TailLoop}, + types::{EdgeKind, Type, TypeRow}, +}; +use itertools::{Either, Itertools}; + +use super::just_types; + +// Analysis: determining all extra ports that must be added ============================= +#[derive(Debug, Clone)] +// Map from (parent of target node) to source Wire to Type. +// `BB` is any container, not necessarily a Basic Block or in a CFG +pub struct BBNeedsSourcesMap(BTreeMap, Type>>); + +impl Default for BBNeedsSourcesMap { + fn default() -> Self { + Self(BTreeMap::default()) + } +} + +impl BBNeedsSourcesMap { + fn insert(&mut self, node: N, source: Wire, ty: Type) -> bool { + self.0.entry(node).or_default().insert(source, ty).is_none() + } + + pub(super) fn get(&self, node: N) -> impl Iterator, &Type)> + '_ { + match self.0.get(&node) { + Some(x) => Either::Left(x.iter()), + None => Either::Right(std::iter::empty()), + } + } + + delegate! { + to self.0 { + pub(super) fn keys(&self) -> impl Iterator; + } + } +} + +#[derive(Debug, Clone)] +pub struct BBNeedsSourcesMapBuilder { + hugr: H, + needs_sources: BBNeedsSourcesMap, +} + +impl BBNeedsSourcesMapBuilder { + pub fn new(hugr: H) -> Self { + Self { + hugr, + needs_sources: Default::default(), + } + } + + pub fn insert(&mut self, mut parent: H::Node, source: Wire, ty: Type) { + let source_parent = self.hugr.get_parent(source.node()).unwrap(); + while source_parent != parent { + if !self.needs_sources.insert(parent, source, ty.clone()) { + break; + } + if self.hugr.get_optype(parent).is_conditional() { + // One of these we must have just done on the previous iteration + for case in self.hugr.children(parent) { + // Full recursion unnecessary as we've just added parent: + self.needs_sources.insert(case, source, ty.clone()); + } + } + // this will panic if source_parent is not an ancestor of target + let parent_parent = self.hugr.get_parent(parent).unwrap(); + if self.hugr.get_optype(parent).is_dataflow_block() { + assert!(self.hugr.get_optype(parent_parent).is_cfg()); + for pred in self.hugr.input_neighbours(parent).collect::>() { + self.insert(pred, source, ty.clone()); + } + if Some(parent) != self.hugr.children(parent_parent).next() { + // Recursive calls on predecessors will have traced back to entry block + // (or source_parent itself if a dominating Basic Block) + break; + } + // We've just added to entry node - so must add to CFG as well + } + parent = parent_parent; + } + } + + pub fn finish(self) -> BBNeedsSourcesMap { + self.needs_sources + } +} + +// Transformation: adding extra ports, and wiring them up =============================== +impl BBNeedsSourcesMap { + pub(super) fn thread_hugr(&self, hugr: &mut impl HugrMut) { + self.thread_node(hugr, hugr.entrypoint(), &HashMap::new()) + } + + fn thread_node( + &self, + hugr: &mut impl HugrMut, + node: N, + locals: &HashMap, Wire>, + ) { + if self.get(node).next().is_none() { + // No edges incoming into this subtree, but there could still be nonlocal edges internal to it + for ch in hugr.children(node).collect::>() { + self.thread_node(hugr, ch, &HashMap::new()) + } + return; + } + + let sources: Vec<(Wire, Type)> = self.get(node).map(|(w, t)| (*w, t.clone())).collect(); + let src_wires: Vec> = sources.iter().map(|(w, _)| *w).collect(); + + // `match` must deal with everything inside the node, and update the signature (per OpType) + let start_new_port_index = match hugr.optype_mut(node) { + OpType::DFG(dfg) => { + let ins = dfg.signature.input.to_mut(); + let start_new_port_index = ins.len(); + ins.extend(just_types(&sources)); + + self.thread_dataflow_parent(hugr, node, start_new_port_index, sources); + start_new_port_index + } + OpType::Conditional(cond) => { + let start_new_port_index = cond.signature().input.len(); + cond.other_inputs.to_mut().extend(just_types(&sources)); + + self.thread_conditional(hugr, node, sources); + start_new_port_index + } + OpType::TailLoop(tail_op) => { + vec_prepend(tail_op.just_inputs.to_mut(), just_types(&sources)); + self.thread_tailloop(hugr, node, sources); + 0 + } + OpType::CFG(cfg) => { + vec_prepend(cfg.signature.input.to_mut(), just_types(&sources)); + assert_eq!( + self.get(node).collect::>(), + self.get(hugr.children(node).next().unwrap()) + .collect::>() + ); // Entry node + for bb in hugr.children(node).collect::>() { + if hugr.get_optype(bb).is_dataflow_block() { + self.thread_bb(hugr, bb); + } + } + 0 + } + _ => panic!( + "All containers handled except Module/FuncDefn or root Case/DFB, which should not have incoming nonlocal edges" + ), + }; + + let new_dfg_ports = hugr.insert_ports( + node, + Direction::Incoming, + start_new_port_index, + src_wires.len(), + ); + let local_srcs = src_wires.into_iter().map(|w| *locals.get(&w).unwrap_or(&w)); + for (w, tgt_port) in local_srcs.zip_eq(new_dfg_ports) { + assert_eq!(hugr.get_parent(w.node()), hugr.get_parent(node)); + hugr.connect(w.node(), w.source(), node, tgt_port) + } + } + + fn thread_dataflow_parent( + &self, + hugr: &mut impl HugrMut, + node: N, + start_new_port_index: usize, + srcs: Vec<(Wire, Type)>, + ) -> HashMap, Wire> { + let nlocals = if srcs.is_empty() { + HashMap::new() + } else { + let (srcs, tys): (Vec<_>, Vec) = srcs.into_iter().unzip(); + let [inp, _] = hugr.get_io(node).unwrap(); + let OpType::Input(in_op) = hugr.optype_mut(inp) else { + panic!("Expected Input node") + }; + vec_insert(in_op.types.to_mut(), tys, start_new_port_index); + let new_outports = + hugr.insert_ports(inp, Direction::Outgoing, start_new_port_index, srcs.len()); + + srcs.into_iter() + .zip_eq(new_outports) + .map(|(w, p)| (w, Wire::new(inp, p))) + .collect() + }; + for ch in hugr.children(node).collect::>() { + for (inp, _) in hugr.in_value_types(ch).collect::>() { + if let Some((src_n, src_p)) = hugr.single_linked_output(ch, inp) { + if hugr.get_parent(src_n) != Some(node) { + hugr.disconnect(ch, inp); + let new_p = nlocals.get(&Wire::new(src_n, src_p)).unwrap(); + hugr.connect(new_p.node(), new_p.source(), ch, inp); + } + } + } + self.thread_node(hugr, ch, &nlocals); + } + nlocals + } + + fn thread_conditional( + &self, + hugr: &mut impl HugrMut, + node: N, + srcs: Vec<(Wire, Type)>, + ) { + for case in hugr.children(node).collect::>() { + let OpType::Case(case_op) = hugr.optype_mut(case) else { + continue; + }; + let ins = case_op.signature.input.to_mut(); + let start_case_port_index = ins.len(); + ins.extend(just_types(&srcs)); + self.thread_dataflow_parent(hugr, case, start_case_port_index, srcs.clone()); + } + } + + fn thread_tailloop( + &self, + hugr: &mut impl HugrMut, + node: N, + srcs: Vec<(Wire, Type)>, + ) { + let [_, o] = hugr.get_io(node).unwrap(); + let new_sum_row_prefixes = { + let mut v = vec![vec![]; 2]; + v[TailLoop::CONTINUE_TAG] = srcs.clone(); + v + }; + add_control_prefixes(hugr, o, new_sum_row_prefixes); + self.thread_dataflow_parent(hugr, node, 0, srcs); + } + + fn thread_bb(&self, hugr: &mut impl HugrMut, node: N) { + let OpType::DataflowBlock(this_dfb) = hugr.optype_mut(node) else { + panic!("Expected dataflow block") + }; + let my_inputs: Vec<_> = self.get(node).map(|(w, t)| (*w, t.clone())).collect(); + vec_prepend(this_dfb.inputs.to_mut(), just_types(&my_inputs)); + let locals = self.thread_dataflow_parent(hugr, node, 0, my_inputs); + let variant_source_prefixes: Vec, Type)>> = hugr + .output_neighbours(node) + .map(|succ| { + // The wires required for each successor block, should be available in the predecessor + self.get(succ) + .map(|(w, ty)| { + ( + if hugr.get_parent(w.node()) == Some(node) { + *w + } else { + *locals.get(w).unwrap() + }, + ty.clone(), + ) + }) + .collect() + }) + .collect(); + let OpType::DataflowBlock(this_dfb) = hugr.optype_mut(node) else { + panic!("It worked earlier!") + }; + for (source_prefix, sum_row) in variant_source_prefixes + .iter() + .zip_eq(this_dfb.sum_rows.iter_mut()) + { + vec_prepend(sum_row.to_mut(), just_types(source_prefix)); + } + let [_, output_node] = hugr.get_io(node).unwrap(); + add_control_prefixes(hugr, output_node, variant_source_prefixes); + } +} + +/// `variant_source_prefixes` are extra wires/types to prepend onto each variant +/// (must have one element per variant of control Sum) +fn add_control_prefixes( + hugr: &mut H, + output_node: H::Node, + variant_source_prefixes: Vec, Type)>>, +) { + debug_assert!(hugr.get_optype(output_node).is_output()); // Just to fail fast + let parent = hugr.get_parent(output_node).unwrap(); + let mut needed_sources = BTreeMap::new(); + let (cond, new_control_type) = { + let Some(EdgeKind::Value(control_type)) = hugr + .get_optype(output_node) + .port_kind(IncomingPort::from(0)) + else { + panic!("impossible") + }; + let Some(sum_type) = control_type.as_sum() else { + panic!("impossible") + }; + + let mut type_for_source = |source: &(Wire, Type)| { + let (w, t) = source; + let replaced = needed_sources.insert(*w, (*w, t.clone())); + debug_assert!(!replaced.is_some_and(|x| x != (*w, t.clone()))); + t.clone() + }; + let old_sum_rows: Vec = sum_type + .variants() + .map(|x| x.clone().try_into().unwrap()) + .collect_vec(); + let new_sum_rows: Vec = + itertools::zip_eq(variant_source_prefixes.iter(), old_sum_rows.iter()) + .map(|(new_sources, old_tys)| { + new_sources + .iter() + .map(&mut type_for_source) + .chain(old_tys.iter().cloned()) + .collect_vec() + .into() + }) + .collect_vec(); + + let new_control_type = Type::new_sum(new_sum_rows.clone()); + let mut cond = ConditionalBuilder::new( + old_sum_rows.clone(), + just_types(needed_sources.values()).collect_vec(), + new_control_type.clone(), + ) + .unwrap(); + for (i, new_sources) in variant_source_prefixes.into_iter().enumerate() { + let mut case = cond.case_builder(i).unwrap(); + let case_inputs = case.input_wires().collect_vec(); + let mut args = new_sources + .into_iter() + .map(|(s, _ty)| { + case_inputs[old_sum_rows[i].len() + + needed_sources + .iter() + .find_position(|(w, _)| **w == s) + .unwrap() + .0] + }) + .collect_vec(); + args.extend(&case_inputs[..old_sum_rows[i].len()]); + let case_outputs = case + .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), args) + .unwrap() + .outputs(); + case.finish_with_outputs(case_outputs).unwrap(); + } + (cond.finish_hugr().unwrap(), new_control_type) + }; + let cond_node = hugr.insert_hugr(parent, cond).inserted_entrypoint; + let (old_output_source_node, old_output_source_port) = + hugr.single_linked_output(output_node, 0).unwrap(); + debug_assert_eq!(hugr.get_parent(old_output_source_node).unwrap(), parent); + hugr.connect(old_output_source_node, old_output_source_port, cond_node, 0); + for (i, &(w, _)) in needed_sources.values().enumerate() { + hugr.connect(w.node(), w.source(), cond_node, i + 1); + } + hugr.disconnect(output_node, IncomingPort::from(0)); + hugr.connect(cond_node, 0, output_node, 0); + let OpType::Output(output) = hugr.optype_mut(output_node) else { + panic!("impossible") + }; + output.types.to_mut()[0] = new_control_type; +} + +fn vec_prepend(v: &mut Vec, ts: impl IntoIterator) { + vec_insert(v, ts, 0) +} + +fn vec_insert(v: &mut Vec, ts: impl IntoIterator, index: usize) { + let mut old_v_iter = std::mem::take(v).into_iter(); + v.extend(old_v_iter.by_ref().take(index).chain(ts)); + v.extend(old_v_iter); +} + +#[cfg(test)] +mod test { + use super::vec_insert; + + #[test] + fn vec_insert0() { + let mut v = vec![5, 7, 9]; + vec_insert(&mut v, [1, 2], 0); + assert_eq!(v, [1, 2, 5, 7, 9]); + } + + #[test] + fn vec_insert1() { + let mut v = vec![5, 7, 9]; + vec_insert(&mut v, [1, 2], 1); + assert_eq!(v, [5, 1, 2, 7, 9]); + } +} From 6e27b62f2c3a545b002a353b48859d95f852a3a4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 13:18:58 +0100 Subject: [PATCH 41/61] Remove WorkItem.target as redundant (node is also key in map, port unused) --- hugr-passes/src/non_local.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index bf86019352..277edc4362 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -56,8 +56,8 @@ fn build_needs_sources_map( nonlocal_edges: &HashMap>, ) -> BBNeedsSourcesMap { let mut bnsm = BBNeedsSourcesMapBuilder::new(&hugr); - for workitem in nonlocal_edges.values() { - let parent = hugr.get_parent(workitem.target.0).unwrap(); + for (target_node, workitem) in nonlocal_edges.iter() { + let parent = hugr.get_parent(*target_node).unwrap(); debug_assert!(hugr.get_parent(parent).is_some()); bnsm.insert(parent, workitem.source, workitem.ty.clone()); } @@ -91,7 +91,6 @@ pub fn ensure_no_nonlocal_edges( #[derive(Debug, Clone)] struct WorkItem { source: Wire, - target: (N, IncomingPort), ty: Type, } @@ -106,7 +105,7 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg // * the type of the non-local edge. Note that all non-local edges are // value edges, so the type is well defined. let nonlocal_edges_map: HashMap<_, _> = nonlocal_edges(hugr) - .filter_map(|target @ (node, inport)| { + .filter_map(|(node, inport)| { let source = { let (n, p) = hugr.single_linked_output(node, inport)?; Wire::new(n, p) @@ -119,7 +118,7 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg else { panic!("impossible") }; - Some((node, WorkItem { source, target, ty })) + Some((node, WorkItem { source, ty })) }) .collect(); From fa10f7b50d4eac6c795af6903280be611fb6a139 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 13:39:50 +0100 Subject: [PATCH 42/61] Inline build_needs_sources_map, BBNeedsSourcesMapBuilder --- hugr-passes/src/non_local.rs | 53 +++++++++----------------- hugr-passes/src/non_local/localize.rs | 55 ++++++++++----------------- 2 files changed, 38 insertions(+), 70 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 277edc4362..14e814b1f6 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -2,20 +2,19 @@ //! in a Hugr and converting them to local edges. use std::collections::HashMap; -use hugr_core::{HugrView, IncomingPort, core::HugrNode}; use itertools::Itertools as _; -mod localize; -use localize::{BBNeedsSourcesMap, BBNeedsSourcesMapBuilder}; - use hugr_core::{ - Wire, + HugrView, IncomingPort, Wire, hugr::{HugrError, hugrmut::HugrMut}, types::{EdgeKind, Type}, }; use crate::ComposablePass; +mod localize; +use localize::BBNeedsSourcesMap; + /// [ComposablePass] that converts all non-local edges in a Hugr /// into local ones, by inserting extra inputs to container nodes /// and extra outports to Input nodes. @@ -50,20 +49,6 @@ pub fn nonlocal_edges(hugr: &H) -> impl Iterator( - hugr: impl HugrView, - nonlocal_edges: &HashMap>, -) -> BBNeedsSourcesMap { - let mut bnsm = BBNeedsSourcesMapBuilder::new(&hugr); - for (target_node, workitem) in nonlocal_edges.iter() { - let parent = hugr.get_parent(*target_node).unwrap(); - debug_assert!(hugr.get_parent(parent).is_some()); - bnsm.insert(parent, workitem.source, workitem.ty.clone()); - } - bnsm.finish() -} - #[deprecated(note = "Use FindNonLocalEdgesError")] pub type NonLocalEdgesError = FindNonLocalEdgesError; @@ -88,12 +73,6 @@ pub fn ensure_no_nonlocal_edges( } } -#[derive(Debug, Clone)] -struct WorkItem { - source: Wire, - ty: Type, -} - fn just_types<'a, X: 'a>(v: impl IntoIterator) -> impl Iterator { v.into_iter().map(|(_, t)| t.clone()) } @@ -118,7 +97,7 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg else { panic!("impossible") }; - Some((node, WorkItem { source, ty })) + Some((node, (source, ty))) }) .collect(); @@ -127,25 +106,27 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg } // We now compute the sources needed by each parent node. - // For a given non-local edge every intermediate node in the hierarchy - // between the source's parent and the target needs that source. - let bb_needs_sources_map = build_needs_sources_map(&hugr, &nonlocal_edges_map); + let bb_needs_sources_map = { + let mut bnsm = BBNeedsSourcesMap::default(); + for (target_node, (source, ty)) in nonlocal_edges_map.iter() { + let parent = hugr.get_parent(*target_node).unwrap(); + debug_assert!(hugr.get_parent(parent).is_some()); + bnsm.add_edge(&*hugr, parent, *source, ty.clone()); + } + bnsm + }; // TODO move this out-of-line #[cfg(debug_assertions)] { - for (&n, wi) in nonlocal_edges_map.iter() { + for (&n, (source, _)) in nonlocal_edges_map.iter() { let mut m = n; loop { let parent = hugr.get_parent(m).unwrap(); - if hugr.get_parent(wi.source.node()).unwrap() == parent { + if hugr.get_parent(source.node()).unwrap() == parent { break; } - assert!( - bb_needs_sources_map - .get(parent) - .any(|(w, _)| *w == wi.source) - ); + assert!(bb_needs_sources_map.get(parent).any(|(w, _)| w == source)); m = parent; } } diff --git a/hugr-passes/src/non_local/localize.rs b/hugr-passes/src/non_local/localize.rs index 9b4d623063..eeaaef3637 100644 --- a/hugr-passes/src/non_local/localize.rs +++ b/hugr-passes/src/non_local/localize.rs @@ -45,43 +45,37 @@ impl BBNeedsSourcesMap { pub(super) fn keys(&self) -> impl Iterator; } } -} - -#[derive(Debug, Clone)] -pub struct BBNeedsSourcesMapBuilder { - hugr: H, - needs_sources: BBNeedsSourcesMap, -} - -impl BBNeedsSourcesMapBuilder { - pub fn new(hugr: H) -> Self { - Self { - hugr, - needs_sources: Default::default(), - } - } - pub fn insert(&mut self, mut parent: H::Node, source: Wire, ty: Type) { - let source_parent = self.hugr.get_parent(source.node()).unwrap(); + // Identify all required extra inputs (deals with both Dom and Ext edges). + // Every intermediate node in the hierarchy + // between the source's parent and the target needs that source. + pub(super) fn add_edge( + &mut self, + hugr: &impl HugrView, + mut parent: N, + source: Wire, + ty: Type, + ) { + let source_parent = hugr.get_parent(source.node()).unwrap(); while source_parent != parent { - if !self.needs_sources.insert(parent, source, ty.clone()) { + if !self.insert(parent, source, ty.clone()) { break; } - if self.hugr.get_optype(parent).is_conditional() { + if hugr.get_optype(parent).is_conditional() { // One of these we must have just done on the previous iteration - for case in self.hugr.children(parent) { + for case in hugr.children(parent) { // Full recursion unnecessary as we've just added parent: - self.needs_sources.insert(case, source, ty.clone()); + self.insert(case, source, ty.clone()); } } // this will panic if source_parent is not an ancestor of target - let parent_parent = self.hugr.get_parent(parent).unwrap(); - if self.hugr.get_optype(parent).is_dataflow_block() { - assert!(self.hugr.get_optype(parent_parent).is_cfg()); - for pred in self.hugr.input_neighbours(parent).collect::>() { - self.insert(pred, source, ty.clone()); + let parent_parent = hugr.get_parent(parent).unwrap(); + if hugr.get_optype(parent).is_dataflow_block() { + assert!(hugr.get_optype(parent_parent).is_cfg()); + for pred in hugr.input_neighbours(parent).collect::>() { + self.add_edge(hugr, pred, source, ty.clone()); } - if Some(parent) != self.hugr.children(parent_parent).next() { + if Some(parent) != hugr.children(parent_parent).next() { // Recursive calls on predecessors will have traced back to entry block // (or source_parent itself if a dominating Basic Block) break; @@ -92,13 +86,6 @@ impl BBNeedsSourcesMapBuilder { } } - pub fn finish(self) -> BBNeedsSourcesMap { - self.needs_sources - } -} - -// Transformation: adding extra ports, and wiring them up =============================== -impl BBNeedsSourcesMap { pub(super) fn thread_hugr(&self, hugr: &mut impl HugrMut) { self.thread_node(hugr, hugr.entrypoint(), &HashMap::new()) } From 78ad38c00d9aeef8032d6a4a03d6efc3b30fce49 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 14:00:07 +0100 Subject: [PATCH 43/61] Tidy debug-asserts, don't use delegate --- hugr-passes/src/non_local.rs | 26 +++++++++----------------- hugr-passes/src/non_local/localize.rs | 12 +++++------- 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 14e814b1f6..3b288dfe9a 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -116,25 +116,17 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg bnsm }; - // TODO move this out-of-line - #[cfg(debug_assertions)] - { - for (&n, (source, _)) in nonlocal_edges_map.iter() { - let mut m = n; - loop { - let parent = hugr.get_parent(m).unwrap(); - if hugr.get_parent(source.node()).unwrap() == parent { - break; - } - assert!(bb_needs_sources_map.get(parent).any(|(w, _)| w == source)); - m = parent; + debug_assert!(nonlocal_edges_map.iter().all(|(n, (source, _))| { + let mut m = *n; + loop { + let parent = hugr.get_parent(m).unwrap(); + if hugr.get_parent(source.node()).unwrap() == parent { + return true; } + assert!(bb_needs_sources_map.parent_needs(parent, *source)); + m = parent; } - - for &bb in bb_needs_sources_map.keys() { - assert!(hugr.get_parent(bb).is_some()); - } - } + })); bb_needs_sources_map.thread_hugr(hugr); diff --git a/hugr-passes/src/non_local/localize.rs b/hugr-passes/src/non_local/localize.rs index eeaaef3637..d9cf7892a8 100644 --- a/hugr-passes/src/non_local/localize.rs +++ b/hugr-passes/src/non_local/localize.rs @@ -2,8 +2,6 @@ use std::collections::{BTreeMap, HashMap}; -use delegate::delegate; - use hugr_core::{ Direction, HugrView, IncomingPort, Wire, builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, @@ -33,17 +31,15 @@ impl BBNeedsSourcesMap { self.0.entry(node).or_default().insert(source, ty).is_none() } - pub(super) fn get(&self, node: N) -> impl Iterator, &Type)> + '_ { + fn get(&self, node: N) -> impl Iterator, &Type)> + '_ { match self.0.get(&node) { Some(x) => Either::Left(x.iter()), None => Either::Right(std::iter::empty()), } } - delegate! { - to self.0 { - pub(super) fn keys(&self) -> impl Iterator; - } + pub(super) fn parent_needs(&self, parent: N, source: Wire) -> bool { + self.get(parent).any(|(w, _)| *w == source) } // Identify all required extra inputs (deals with both Dom and Ext edges). @@ -58,6 +54,7 @@ impl BBNeedsSourcesMap { ) { let source_parent = hugr.get_parent(source.node()).unwrap(); while source_parent != parent { + debug_assert!(hugr.get_parent(parent).is_some()); if !self.insert(parent, source, ty.clone()) { break; } @@ -86,6 +83,7 @@ impl BBNeedsSourcesMap { } } + /// Threads the extra connections required throughout the Hugr pub(super) fn thread_hugr(&self, hugr: &mut impl HugrMut) { self.thread_node(hugr, hugr.entrypoint(), &HashMap::new()) } From 4681351d1b51cecac93af9bf73e6c5ef1e4c2d5a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 14:01:22 +0100 Subject: [PATCH 44/61] Don't import delegate --- Cargo.lock | 1 - hugr-passes/Cargo.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 641f28d514..85bf02bfb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1325,7 +1325,6 @@ name = "hugr-passes" version = "0.20.0" dependencies = [ "ascent", - "delegate", "derive_more 1.0.0", "hugr-core", "itertools 0.14.0", diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 4dbb9c6364..4a6a006458 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -28,7 +28,6 @@ lazy_static = { workspace = true } paste = { workspace = true } thiserror = { workspace = true } petgraph = { workspace = true } -delegate.workspace = true strum = { workspace = true } [dev-dependencies] From ac7a234d696dd48d529b0b57c3bc80a14f7665d5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 14:20:59 +0100 Subject: [PATCH 45/61] LocalizeEdgesError doesn't actually need any variants, hmmm --- hugr-passes/src/non_local.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 3b288dfe9a..e9c36e143d 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -6,7 +6,7 @@ use itertools::Itertools as _; use hugr_core::{ HugrView, IncomingPort, Wire, - hugr::{HugrError, hugrmut::HugrMut}, + hugr::hugrmut::HugrMut, types::{EdgeKind, Type}, }; @@ -22,9 +22,7 @@ pub struct LocalizeEdges; #[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug, PartialEq)] #[non_exhaustive] -pub enum LocalizeEdgesError { - HugrError(#[from] HugrError), -} +pub enum LocalizeEdgesError {} impl ComposablePass for LocalizeEdges { type Error = LocalizeEdgesError; From e9c0f080b05e9f21791187994850c6d78f87f117 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 14:21:35 +0100 Subject: [PATCH 46/61] comments --- hugr-passes/src/non_local.rs | 7 ++----- hugr-passes/src/non_local/localize.rs | 11 ++++++++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index e9c36e143d..9b9bb06017 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -76,11 +76,8 @@ fn just_types<'a, X: 'a>(v: impl IntoIterator) -> impl Ite } pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdgesError> { - // First we collect all the non-local edges in the graph. We associate them to a WorkItem, which tracks: - // * the source of the non-local edge - // * the target of the non-local edge - // * the type of the non-local edge. Note that all non-local edges are - // value edges, so the type is well defined. + // Group all the non-local edges in the graph by target node, + // storing for each the source and type (well-defined as these are Value edges). let nonlocal_edges_map: HashMap<_, _> = nonlocal_edges(hugr) .filter_map(|(node, inport)| { let source = { diff --git a/hugr-passes/src/non_local/localize.rs b/hugr-passes/src/non_local/localize.rs index d9cf7892a8..da8f806c3b 100644 --- a/hugr-passes/src/non_local/localize.rs +++ b/hugr-passes/src/non_local/localize.rs @@ -14,7 +14,6 @@ use itertools::{Either, Itertools}; use super::just_types; -// Analysis: determining all extra ports that must be added ============================= #[derive(Debug, Clone)] // Map from (parent of target node) to source Wire to Type. // `BB` is any container, not necessarily a Basic Block or in a CFG @@ -72,12 +71,13 @@ impl BBNeedsSourcesMap { for pred in hugr.input_neighbours(parent).collect::>() { self.add_edge(hugr, pred, source, ty.clone()); } - if Some(parent) != hugr.children(parent_parent).next() { + if Some(parent) == hugr.children(parent_parent).next() { + // We've just added to entry node - so carry on and add to CFG as well + } else { // Recursive calls on predecessors will have traced back to entry block // (or source_parent itself if a dominating Basic Block) break; } - // We've just added to entry node - so must add to CFG as well } parent = parent_parent; } @@ -88,6 +88,7 @@ impl BBNeedsSourcesMap { self.thread_node(hugr, hugr.entrypoint(), &HashMap::new()) } + // keys of `locals` are the *original* sources of the non-local edges, in self.0. fn thread_node( &self, hugr: &mut impl HugrMut, @@ -159,6 +160,7 @@ impl BBNeedsSourcesMap { } } + // Add to Input node; assume container type already updated. fn thread_dataflow_parent( &self, hugr: &mut impl HugrMut, @@ -198,6 +200,7 @@ impl BBNeedsSourcesMap { nlocals } + // Add to children (assuming conditional already updated). fn thread_conditional( &self, hugr: &mut impl HugrMut, @@ -215,6 +218,7 @@ impl BBNeedsSourcesMap { } } + // Add to body of loop (assume TailLoop node itself already updated). fn thread_tailloop( &self, hugr: &mut impl HugrMut, @@ -231,6 +235,7 @@ impl BBNeedsSourcesMap { self.thread_dataflow_parent(hugr, node, 0, srcs); } + // Add to DataflowBlock *and* inner dataflow sibling subgraph fn thread_bb(&self, hugr: &mut impl HugrMut, node: N) { let OpType::DataflowBlock(this_dfb) = hugr.optype_mut(node) else { panic!("Expected dataflow block") From b069c1e7270e032795e45404c2614acc6239df0a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 14:50:03 +0100 Subject: [PATCH 47/61] Rewrite assertion to make clippy happy clippy objected to use of `mut` inside a debug_assert!, even though the mutated thing was local to the assert... --- hugr-passes/src/non_local.rs | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 9b9bb06017..66f2b1518a 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -112,15 +112,12 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg }; debug_assert!(nonlocal_edges_map.iter().all(|(n, (source, _))| { - let mut m = *n; - loop { - let parent = hugr.get_parent(m).unwrap(); - if hugr.get_parent(source.node()).unwrap() == parent { - return true; - } - assert!(bb_needs_sources_map.parent_needs(parent, *source)); - m = parent; - } + std::iter::successors(Some(*n), |n| { + let parent = hugr.get_parent(*n).unwrap(); + (Some(parent) != hugr.get_parent(source.node())).then_some(parent) + }) + .skip(1) + .all(|parent| bb_needs_sources_map.parent_needs(parent, *source)) })); bb_needs_sources_map.thread_hugr(hugr); From 2e4e85c95eeb56a41544b6c5e7d08d6dd6bd433d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 17:17:16 +0100 Subject: [PATCH 48/61] CFG test includes internal (dom) edge; extend debug assertion w/ grandparent --- hugr-passes/src/non_local.rs | 233 +++++++++++++++++++++-------------- 1 file changed, 139 insertions(+), 94 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 66f2b1518a..5eb77e0205 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -112,9 +112,14 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg }; debug_assert!(nonlocal_edges_map.iter().all(|(n, (source, _))| { + let source_parent = hugr.get_parent(source.node()).unwrap(); + let source_gp = hugr.get_parent(source_parent); + let stop_at = source_gp + .map(|gp| vec![source_parent, gp]) + .unwrap_or(vec![source_parent]); std::iter::successors(Some(*n), |n| { let parent = hugr.get_parent(*n).unwrap(); - (Some(parent) != hugr.get_parent(source.node())).then_some(parent) + (!stop_at.contains(&parent)).then_some(parent) }) .skip(1) .all(|parent| bb_needs_sources_map.parent_needs(parent, *source)) @@ -309,103 +314,143 @@ mod test { // sums. // // All branches have an other-output. - let mut hugr = { - let branch_sum_type = either_type(Type::UNIT, Type::UNIT); - let branch_type = Type::from(branch_sum_type.clone()); - let branch_variants = branch_sum_type - .variants() - .cloned() - .map(|x| x.try_into().unwrap()) - .collect_vec(); - let nonlocal1_type = bool_t(); - let nonlocal2_type = Type::new_unit_sum(3); - let other_output_type = branch_type.clone(); - let mut outer = DFGBuilder::new(Signature::new( - vec![ - branch_type.clone(), - nonlocal1_type.clone(), - nonlocal2_type.clone(), - Type::UNIT, - ], - vec![Type::UNIT, other_output_type.clone()], - )) + let branch_sum_type = either_type(Type::UNIT, Type::UNIT); + let branch_type = Type::from(branch_sum_type.clone()); + let branch_variants = branch_sum_type + .variants() + .cloned() + .map(|x| x.try_into().unwrap()) + .collect_vec(); + let nonlocal1_type = bool_t(); + let nonlocal2_type = Type::new_unit_sum(3); + let other_output_type = branch_type.clone(); + let mut outer = DFGBuilder::new(Signature::new( + vec![branch_type.clone(), nonlocal1_type.clone(), Type::UNIT], + vec![Type::UNIT, other_output_type.clone()], + )) + .unwrap(); + let [b, nl1, unit] = outer.input_wires_arr(); + let mut cfg = outer + .cfg_builder( + [(Type::UNIT, unit), (branch_type.clone(), b)], + vec![Type::UNIT, other_output_type.clone()].into(), + ) .unwrap(); - let [b, nl1, nl2, unit] = outer.input_wires_arr(); - let [unit, out] = { - let mut cfg = outer - .cfg_builder( - [(Type::UNIT, unit), (branch_type.clone(), b)], - vec![Type::UNIT, other_output_type.clone()].into(), - ) - .unwrap(); - let entry = { - let entry = cfg - .entry_builder(branch_variants.clone(), other_output_type.clone().into()) - .unwrap(); - let [_, b] = entry.input_wires_arr(); - - entry.finish_with_outputs(b, [b]).unwrap() - }; - let exit = cfg.exit_block(); - - let bb_left = { - let mut entry = cfg - .block_builder( - vec![Type::UNIT, other_output_type.clone()].into(), - [type_row![]], - other_output_type.clone().into(), - ) - .unwrap(); - let [unit, oo] = entry.input_wires_arr(); - let [_] = entry - .add_dataflow_op(Noop::new(nonlocal1_type), [nl1]) - .unwrap() - .outputs_arr(); - let [_] = entry - .add_dataflow_op(Noop::new(nonlocal2_type), [nl2]) - .unwrap() - .outputs_arr(); - entry.finish_with_outputs(unit, [oo]).unwrap() - }; - - let bb_right = { - let entry = cfg - .block_builder( - vec![Type::UNIT, other_output_type.clone()].into(), - [type_row![]], - other_output_type.clone().into(), - ) - .unwrap(); - let [_b, oo] = entry.input_wires_arr(); - entry.finish_with_outputs(unit, [oo]).unwrap() - }; - - let bb_bottom = { - let entry = cfg - .block_builder( - branch_type.clone().into(), - branch_variants, - other_output_type.clone().into(), - ) - .unwrap(); - let [oo] = entry.input_wires_arr(); - entry.finish_with_outputs(oo, [oo]).unwrap() - }; - cfg.branch(&entry, 0, &bb_left).unwrap(); - cfg.branch(&entry, 1, &bb_right).unwrap(); - cfg.branch(&bb_left, 0, &bb_bottom).unwrap(); - cfg.branch(&bb_right, 0, &bb_bottom).unwrap(); - cfg.branch(&bb_bottom, 0, &entry).unwrap(); - cfg.branch(&bb_bottom, 1, &exit).unwrap(); - cfg.finish_sub_container().unwrap().outputs_arr() - }; - outer.finish_hugr_with_outputs([unit, out]).unwrap() + let (entry, cst) = { + let mut entry = cfg + .entry_builder(branch_variants.clone(), other_output_type.clone().into()) + .unwrap(); + let [_, b] = entry.input_wires_arr(); + + let cst = entry.add_load_value(Value::unit_sum(1, 3).unwrap()); + + (entry.finish_with_outputs(b, [b]).unwrap(), cst) }; - assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + let exit = cfg.exit_block(); + + let (bb_left, tgt_ext, tgt_dom) = { + let mut bb = cfg + .block_builder( + vec![Type::UNIT, other_output_type.clone()].into(), + [type_row![]], + other_output_type.clone().into(), + ) + .unwrap(); + let [unit, oo] = bb.input_wires_arr(); + let tgt_ext = bb + .add_dataflow_op(Noop::new(nonlocal1_type.clone()), [nl1]) + .unwrap(); + + let tgt_dom = bb + .add_dataflow_op(Noop::new(nonlocal2_type.clone()), [cst]) + .unwrap(); + ( + bb.finish_with_outputs(unit, [oo]).unwrap(), + tgt_ext, + tgt_dom, + ) + }; + + let bb_right = { + let mut bb = cfg + .block_builder( + vec![Type::UNIT, other_output_type.clone()].into(), + [type_row![]], + other_output_type.clone().into(), + ) + .unwrap(); + let [_b, oo] = bb.input_wires_arr(); + let unit = bb.add_load_value(Value::unit()); + bb.finish_with_outputs(unit, [oo]).unwrap() + }; + + let bb_bottom = { + let bb = cfg + .block_builder( + branch_type.clone().into(), + branch_variants, + other_output_type.clone().into(), + ) + .unwrap(); + let [oo] = bb.input_wires_arr(); + bb.finish_with_outputs(oo, [oo]).unwrap() + }; + cfg.branch(&entry, 0, &bb_left).unwrap(); + cfg.branch(&entry, 1, &bb_right).unwrap(); + cfg.branch(&bb_left, 0, &bb_bottom).unwrap(); + cfg.branch(&bb_right, 0, &bb_bottom).unwrap(); + cfg.branch(&bb_bottom, 0, &entry).unwrap(); + cfg.branch(&bb_bottom, 1, &exit).unwrap(); + let [unit, out] = cfg.finish_sub_container().unwrap().outputs_arr(); + + let mut hugr = outer.finish_hugr_with_outputs([unit, out]).unwrap(); + eprintln!("ALAN tgt_ext {tgt_ext:?} tgt_dom {tgt_dom:?}"); + eprintln!("{}", hugr.mermaid_string()); + let Err(FindNonLocalEdgesError::Edges(es)) = ensure_no_nonlocal_edges(&hugr) else { + panic!() + }; + assert_eq!( + es, + vec![ + (tgt_ext.node(), IncomingPort::from(0)), + (tgt_dom.node(), IncomingPort::from(0)) + ] + ); remove_nonlocal_edges(&mut hugr).unwrap(); - println!("{}", hugr.mermaid_string()); - hugr.validate().unwrap_or_else(|e| panic!("{e}")); + hugr.validate().unwrap(); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + // Entry node gets nonlocal1_type added, only + assert_eq!( + hugr.get_optype(entry.node()) + .as_dataflow_block() + .unwrap() + .inputs + .as_slice(), + &[nonlocal1_type.clone(), Type::UNIT, branch_type.clone()] + ); + // Left node gets both nonlocal1_type and nonlocal2_type + assert_eq!( + hugr.get_optype(bb_left.node()) + .as_dataflow_block() + .unwrap() + .inputs + .as_slice(), + &[ + nonlocal1_type.clone(), + nonlocal2_type, + Type::UNIT, + other_output_type + ] + ); + // Bottom node gets nonlocal1_type added, only + assert_eq!( + hugr.get_optype(bb_bottom.node()) + .as_dataflow_block() + .unwrap() + .inputs + .as_slice(), + &[nonlocal1_type, branch_type] + ); } } From 59f1c890a474902820daeed1ac882b8a09a34ce2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 20 May 2025 17:25:56 +0100 Subject: [PATCH 49/61] Tidy checks --- hugr-passes/src/non_local.rs | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 5eb77e0205..d2325d7cc8 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -135,9 +135,10 @@ mod test { use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, extension::prelude::{Noop, bool_t, either_type}, - ops::{Tag, TailLoop, Value, handle::NodeHandle}, + ops::handle::{BasicBlockID, NodeHandle}, + ops::{Tag, TailLoop, Value}, type_row, - types::Signature, + types::{Signature, TypeRow}, }; use super::*; @@ -420,23 +421,16 @@ mod test { remove_nonlocal_edges(&mut hugr).unwrap(); hugr.validate().unwrap(); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + let dfb = |bb: BasicBlockID| hugr.get_optype(bb.node()).as_dataflow_block().unwrap(); // Entry node gets nonlocal1_type added, only assert_eq!( - hugr.get_optype(entry.node()) - .as_dataflow_block() - .unwrap() - .inputs - .as_slice(), - &[nonlocal1_type.clone(), Type::UNIT, branch_type.clone()] + dfb(entry).inputs[..], + [nonlocal1_type.clone(), Type::UNIT, branch_type.clone()] ); // Left node gets both nonlocal1_type and nonlocal2_type assert_eq!( - hugr.get_optype(bb_left.node()) - .as_dataflow_block() - .unwrap() - .inputs - .as_slice(), - &[ + dfb(bb_left).inputs[..], + [ nonlocal1_type.clone(), nonlocal2_type, Type::UNIT, @@ -444,13 +438,6 @@ mod test { ] ); // Bottom node gets nonlocal1_type added, only - assert_eq!( - hugr.get_optype(bb_bottom.node()) - .as_dataflow_block() - .unwrap() - .inputs - .as_slice(), - &[nonlocal1_type, branch_type] - ); + assert_eq!(dfb(bb_bottom).inputs[..], [nonlocal1_type, branch_type]); } } From e2f11d465497f42d035ce1d5805e760c005e93f5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 21 May 2025 09:08:11 +0100 Subject: [PATCH 50/61] Rename BBNeedsSourceMap to ExtraSourceReqs, a few comments --- hugr-passes/src/non_local.rs | 6 +++--- hugr-passes/src/non_local/localize.rs | 26 +++++++++++++++----------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index d2325d7cc8..10f54f687e 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -13,7 +13,7 @@ use hugr_core::{ use crate::ComposablePass; mod localize; -use localize::BBNeedsSourcesMap; +use localize::ExtraSourceReqs; /// [ComposablePass] that converts all non-local edges in a Hugr /// into local ones, by inserting extra inputs to container nodes @@ -102,7 +102,7 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg // We now compute the sources needed by each parent node. let bb_needs_sources_map = { - let mut bnsm = BBNeedsSourcesMap::default(); + let mut bnsm = ExtraSourceReqs::default(); for (target_node, (source, ty)) in nonlocal_edges_map.iter() { let parent = hugr.get_parent(*target_node).unwrap(); debug_assert!(hugr.get_parent(parent).is_some()); @@ -138,7 +138,7 @@ mod test { ops::handle::{BasicBlockID, NodeHandle}, ops::{Tag, TailLoop, Value}, type_row, - types::{Signature, TypeRow}, + types::Signature, }; use super::*; diff --git a/hugr-passes/src/non_local/localize.rs b/hugr-passes/src/non_local/localize.rs index da8f806c3b..562f352e3a 100644 --- a/hugr-passes/src/non_local/localize.rs +++ b/hugr-passes/src/non_local/localize.rs @@ -15,17 +15,17 @@ use itertools::{Either, Itertools}; use super::just_types; #[derive(Debug, Clone)] -// Map from (parent of target node) to source Wire to Type. -// `BB` is any container, not necessarily a Basic Block or in a CFG -pub struct BBNeedsSourcesMap(BTreeMap, Type>>); +// For each parent/container node, a map from the source Wires that need to be added +// as extra inputs to that container, to the Type of each. +pub(super) struct ExtraSourceReqs(BTreeMap, Type>>); -impl Default for BBNeedsSourcesMap { +impl Default for ExtraSourceReqs { fn default() -> Self { Self(BTreeMap::default()) } } -impl BBNeedsSourcesMap { +impl ExtraSourceReqs { fn insert(&mut self, node: N, source: Wire, ty: Type) -> bool { self.0.entry(node).or_default().insert(source, ty).is_none() } @@ -37,14 +37,14 @@ impl BBNeedsSourcesMap { } } - pub(super) fn parent_needs(&self, parent: N, source: Wire) -> bool { + pub fn parent_needs(&self, parent: N, source: Wire) -> bool { self.get(parent).any(|(w, _)| *w == source) } - // Identify all required extra inputs (deals with both Dom and Ext edges). - // Every intermediate node in the hierarchy - // between the source's parent and the target needs that source. - pub(super) fn add_edge( + /// Identify all required extra inputs (deals with both Dom and Ext edges). + /// Every intermediate node in the hierarchy + /// between the source's parent and the target needs that source. + pub fn add_edge( &mut self, hugr: &impl HugrView, mut parent: N, @@ -64,10 +64,14 @@ impl BBNeedsSourcesMap { self.insert(case, source, ty.clone()); } } - // this will panic if source_parent is not an ancestor of target + // this will eventually panic if source_parent is not an ancestor of target let parent_parent = hugr.get_parent(parent).unwrap(); + if hugr.get_optype(parent).is_dataflow_block() { assert!(hugr.get_optype(parent_parent).is_cfg()); + // For both Dom edges and Ext edges from outside the CFG, also add to all + // reaching BBs (for a Dom edge, up to but not including the source BB: + // all paths eventually come from the source since it dominates the target). for pred in hugr.input_neighbours(parent).collect::>() { self.add_edge(hugr, pred, source, ty.clone()); } From ceaa552e6a863ff56f635534db9eb92a70415638 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 23 May 2025 11:53:43 +0100 Subject: [PATCH 51/61] Remove unnecessary filter; HashMap -> Vec so no dedup --- hugr-passes/src/non_local.rs | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 10f54f687e..44f5ccbe3d 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,7 +1,5 @@ //! This module provides functions for finding non-local edges //! in a Hugr and converting them to local edges. -use std::collections::HashMap; - use itertools::Itertools as _; use hugr_core::{ @@ -78,10 +76,11 @@ fn just_types<'a, X: 'a>(v: impl IntoIterator) -> impl Ite pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdgesError> { // Group all the non-local edges in the graph by target node, // storing for each the source and type (well-defined as these are Value edges). - let nonlocal_edges_map: HashMap<_, _> = nonlocal_edges(hugr) - .filter_map(|(node, inport)| { + let nonlocal_edges: Vec<_> = nonlocal_edges(hugr) + .map(|(node, inport)| { let source = { - let (n, p) = hugr.single_linked_output(node, inport)?; + // unwrap because nonlocal_edges(hugr) already skips in-ports with !=1 linked outputs. + let (n, p) = hugr.single_linked_output(node, inport).unwrap(); Wire::new(n, p) }; debug_assert!( @@ -92,18 +91,18 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg else { panic!("impossible") }; - Some((node, (source, ty))) + (node, (source, ty)) }) .collect(); - if nonlocal_edges_map.is_empty() { + if nonlocal_edges.is_empty() { return Ok(()); } // We now compute the sources needed by each parent node. let bb_needs_sources_map = { let mut bnsm = ExtraSourceReqs::default(); - for (target_node, (source, ty)) in nonlocal_edges_map.iter() { + for (target_node, (source, ty)) in nonlocal_edges.iter() { let parent = hugr.get_parent(*target_node).unwrap(); debug_assert!(hugr.get_parent(parent).is_some()); bnsm.add_edge(&*hugr, parent, *source, ty.clone()); @@ -111,7 +110,7 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg bnsm }; - debug_assert!(nonlocal_edges_map.iter().all(|(n, (source, _))| { + debug_assert!(nonlocal_edges.iter().all(|(n, (source, _))| { let source_parent = hugr.get_parent(source.node()).unwrap(); let source_gp = hugr.get_parent(source_parent); let stop_at = source_gp From aa2daf81028d56569c075d9789e351b408a95242 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 23 May 2025 12:01:16 +0100 Subject: [PATCH 52/61] Docs --- hugr-passes/src/non_local.rs | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 44f5ccbe3d..599b48bc3c 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,5 +1,6 @@ //! This module provides functions for finding non-local edges //! in a Hugr and converting them to local edges. +#![warn(missing_docs)] use itertools::Itertools as _; use hugr_core::{ @@ -13,11 +14,10 @@ use crate::ComposablePass; mod localize; use localize::ExtraSourceReqs; -/// [ComposablePass] that converts all non-local edges in a Hugr -/// into local ones, by inserting extra inputs to container nodes -/// and extra outports to Input nodes. +/// [ComposablePass] wrapper for [remove_nonlocal_edges] pub struct LocalizeEdges; +/// Error from [LocalizeEdges] or [remove_nonlocal_edges] #[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug, PartialEq)] #[non_exhaustive] pub enum LocalizeEdgesError {} @@ -45,6 +45,7 @@ pub fn nonlocal_edges(hugr: &H) -> impl Iterator = FindNonLocalEdgesError; @@ -52,8 +53,9 @@ pub type NonLocalEdgesError = FindNonLocalEdgesError; #[derive(Clone, derive_more::Error, derive_more::Display, Debug, PartialEq, Eq)] #[non_exhaustive] pub enum FindNonLocalEdgesError { + /// Some nonlocal edges were found #[display("Found {} nonlocal edges", _0.len())] - #[error(ignore)] + #[error(ignore)] // Vec not convertible Edges(Vec<(N, IncomingPort)>), } @@ -73,6 +75,11 @@ fn just_types<'a, X: 'a>(v: impl IntoIterator) -> impl Ite v.into_iter().map(|(_, t)| t.clone()) } +/// Converts all non-local edges in a Hugr into local ones, by inserting extra inputs +/// to container nodes and extra outports to Input nodes (and conversely to outputs of +/// [DataflowBlock]s). +/// +/// [DataflowBlock]: hugr_core::ops::DataflowBlock pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdgesError> { // Group all the non-local edges in the graph by target node, // storing for each the source and type (well-defined as these are Value edges). @@ -100,7 +107,7 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg } // We now compute the sources needed by each parent node. - let bb_needs_sources_map = { + let needs_sources_map = { let mut bnsm = ExtraSourceReqs::default(); for (target_node, (source, ty)) in nonlocal_edges.iter() { let parent = hugr.get_parent(*target_node).unwrap(); @@ -121,10 +128,10 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg (!stop_at.contains(&parent)).then_some(parent) }) .skip(1) - .all(|parent| bb_needs_sources_map.parent_needs(parent, *source)) + .all(|parent| needs_sources_map.parent_needs(parent, *source)) })); - bb_needs_sources_map.thread_hugr(hugr); + needs_sources_map.thread_hugr(hugr); Ok(()) } From b4417f379061df258d28b2f6b73d2be32abbd238 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 23 May 2025 12:04:45 +0100 Subject: [PATCH 53/61] Improve ExtraSourceReqs::get --- hugr-passes/src/non_local/localize.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/non_local/localize.rs b/hugr-passes/src/non_local/localize.rs index 562f352e3a..15f5dedc09 100644 --- a/hugr-passes/src/non_local/localize.rs +++ b/hugr-passes/src/non_local/localize.rs @@ -10,7 +10,7 @@ use hugr_core::{ ops::{DataflowOpTrait, OpType, Tag, TailLoop}, types::{EdgeKind, Type, TypeRow}, }; -use itertools::{Either, Itertools}; +use itertools::Itertools as _; use super::just_types; @@ -31,10 +31,7 @@ impl ExtraSourceReqs { } fn get(&self, node: N) -> impl Iterator, &Type)> + '_ { - match self.0.get(&node) { - Some(x) => Either::Left(x.iter()), - None => Either::Right(std::iter::empty()), - } + self.0.get(&node).into_iter().flat_map(BTreeMap::iter) } pub fn parent_needs(&self, parent: N, source: Wire) -> bool { From 499dd69e2006eaf812c5afdf39cf73f0b073869f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 May 2025 09:00:13 +0100 Subject: [PATCH 54/61] derives for LocalizeEdges --- hugr-passes/src/non_local.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 599b48bc3c..fd2008e54c 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -15,6 +15,7 @@ mod localize; use localize::ExtraSourceReqs; /// [ComposablePass] wrapper for [remove_nonlocal_edges] +#[derive(Clone, Debug, Hash)] pub struct LocalizeEdges; /// Error from [LocalizeEdges] or [remove_nonlocal_edges] From d1d01f9a48ae8f115cb2d72cbd2b5691a7bc67d3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 May 2025 09:00:37 +0100 Subject: [PATCH 55/61] simplify map func by creating Wire later --- hugr-passes/src/non_local.rs | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index fd2008e54c..5d43fb22dc 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -86,20 +86,13 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg // storing for each the source and type (well-defined as these are Value edges). let nonlocal_edges: Vec<_> = nonlocal_edges(hugr) .map(|(node, inport)| { - let source = { - // unwrap because nonlocal_edges(hugr) already skips in-ports with !=1 linked outputs. - let (n, p) = hugr.single_linked_output(node, inport).unwrap(); - Wire::new(n, p) - }; - debug_assert!( - hugr.get_parent(source.node()).unwrap() != hugr.get_parent(node).unwrap() - ); - let Some(EdgeKind::Value(ty)) = - hugr.get_optype(source.node()).port_kind(source.source()) - else { + // unwrap because nonlocal_edges(hugr) already skips in-ports with !=1 linked outputs. + let (src_n, outp) = hugr.single_linked_output(node, inport).unwrap(); + debug_assert!(hugr.get_parent(src_n).unwrap() != hugr.get_parent(node).unwrap()); + let Some(EdgeKind::Value(ty)) = hugr.get_optype(src_n).port_kind(outp) else { panic!("impossible") }; - (node, (source, ty)) + (node, (Wire::new(src_n, outp), ty)) }) .collect(); From c8de933d7e1793bb6815c2b0425f14321e4c00e2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 May 2025 09:32:51 +0100 Subject: [PATCH 56/61] Simplify debug-assert, filter nonlocal_edges --- hugr-passes/src/non_local.rs | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 5d43fb22dc..93290da7c1 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -33,15 +33,19 @@ impl ComposablePass for LocalizeEdges { } } -/// Returns an iterator over all non local edges in a Hugr. +/// Returns an iterator over all non local edges in a Hugr beneath the entrypoint. /// -/// All `(node, in_port)` pairs are returned where `in_port` is a value port -/// connected to a node with a parent other than the parent of `node`. +/// All `(node, in_port)` pairs are returned where `in_port` is a value port connected +/// to a node beneath the entrypoint but with a parent other than the parent of `node`. pub fn nonlocal_edges(hugr: &H) -> impl Iterator + '_ { hugr.entry_descendants().flat_map(move |node| { hugr.in_value_types(node).filter_map(move |(in_p, _)| { let (src, _) = hugr.single_linked_output(node, in_p)?; - (hugr.get_parent(node) != hugr.get_parent(src)).then_some((node, in_p)) + (hugr.get_parent(node) != hugr.get_parent(src) + && ancestors(src, hugr) + .find(|a| *a == hugr.entrypoint()) + .is_some()) + .then_some((node, in_p)) }) }) } @@ -114,15 +118,10 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg debug_assert!(nonlocal_edges.iter().all(|(n, (source, _))| { let source_parent = hugr.get_parent(source.node()).unwrap(); let source_gp = hugr.get_parent(source_parent); - let stop_at = source_gp - .map(|gp| vec![source_parent, gp]) - .unwrap_or(vec![source_parent]); - std::iter::successors(Some(*n), |n| { - let parent = hugr.get_parent(*n).unwrap(); - (!stop_at.contains(&parent)).then_some(parent) - }) - .skip(1) - .all(|parent| needs_sources_map.parent_needs(parent, *source)) + ancestors(*n, hugr) + .skip(1) + .take_while(|&a| a != source_parent && source_gp.is_none_or(|gp| a != gp)) + .all(|parent| needs_sources_map.parent_needs(parent, *source)) })); needs_sources_map.thread_hugr(hugr); @@ -130,6 +129,10 @@ pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdg Ok(()) } +fn ancestors(n: H::Node, h: &H) -> impl Iterator { + std::iter::successors(Some(n), |n| h.get_parent(*n)) +} + #[cfg(test)] mod test { use hugr_core::{ From 6f1875c3a6e9a70bc255a7c355b241d15e3c2581 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 May 2025 10:40:36 +0100 Subject: [PATCH 57/61] improve comment on nonlocal_edges more --- hugr-passes/src/non_local.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 93290da7c1..521c1b4c07 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -35,8 +35,8 @@ impl ComposablePass for LocalizeEdges { /// Returns an iterator over all non local edges in a Hugr beneath the entrypoint. /// -/// All `(node, in_port)` pairs are returned where `in_port` is a value port connected -/// to a node beneath the entrypoint but with a parent other than the parent of `node`. +/// All `(node, in_port)` pairs are returned where `in_port` is a value port connected to a +/// node whose parent is both beneath the entrypoint and different from the parent of `node`. pub fn nonlocal_edges(hugr: &H) -> impl Iterator + '_ { hugr.entry_descendants().flat_map(move |node| { hugr.in_value_types(node).filter_map(move |(in_p, _)| { From 6784b9d3988e41524d81c1bf4891c13352fa907c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 May 2025 10:29:57 +0100 Subject: [PATCH 58/61] Clarify localize_cfg, remove debug printout --- hugr-passes/src/non_local.rs | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 521c1b4c07..a6b357622d 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -310,8 +310,9 @@ mod test { // The 4 dataflow blocks form a diamond, and the bottom block branches // either to the entry block or the exit block. // - // Two non-local uses in the left block means that these values must - // be threaded through all blocks, because of the loop. + // The left block contains non-local uses of a value from outside the CFG (ext edge) + // and a value from the entry block (dom edge) - the `ext` must be threaded through + // all blocks because of the loop, the `dom` stays within (the same iter of) the loop. // // All non-trivial(i.e. more than one choice of successor) branching is // done on an option type to exercise both empty and occupied control @@ -325,15 +326,15 @@ mod test { .cloned() .map(|x| x.try_into().unwrap()) .collect_vec(); - let nonlocal1_type = bool_t(); - let nonlocal2_type = Type::new_unit_sum(3); + let ext_edge_type = bool_t(); + let dom_edge_type = Type::new_unit_sum(3); let other_output_type = branch_type.clone(); let mut outer = DFGBuilder::new(Signature::new( - vec![branch_type.clone(), nonlocal1_type.clone(), Type::UNIT], + vec![branch_type.clone(), ext_edge_type.clone(), Type::UNIT], vec![Type::UNIT, other_output_type.clone()], )) .unwrap(); - let [b, nl1, unit] = outer.input_wires_arr(); + let [b, src_ext, unit] = outer.input_wires_arr(); let mut cfg = outer .cfg_builder( [(Type::UNIT, unit), (branch_type.clone(), b)], @@ -341,7 +342,7 @@ mod test { ) .unwrap(); - let (entry, cst) = { + let (entry, src_dom) = { let mut entry = cfg .entry_builder(branch_variants.clone(), other_output_type.clone().into()) .unwrap(); @@ -363,11 +364,11 @@ mod test { .unwrap(); let [unit, oo] = bb.input_wires_arr(); let tgt_ext = bb - .add_dataflow_op(Noop::new(nonlocal1_type.clone()), [nl1]) + .add_dataflow_op(Noop::new(ext_edge_type.clone()), [src_ext]) .unwrap(); let tgt_dom = bb - .add_dataflow_op(Noop::new(nonlocal2_type.clone()), [cst]) + .add_dataflow_op(Noop::new(dom_edge_type.clone()), [src_dom]) .unwrap(); ( bb.finish_with_outputs(unit, [oo]).unwrap(), @@ -409,8 +410,6 @@ mod test { let [unit, out] = cfg.finish_sub_container().unwrap().outputs_arr(); let mut hugr = outer.finish_hugr_with_outputs([unit, out]).unwrap(); - eprintln!("ALAN tgt_ext {tgt_ext:?} tgt_dom {tgt_dom:?}"); - eprintln!("{}", hugr.mermaid_string()); let Err(FindNonLocalEdgesError::Edges(es)) = ensure_no_nonlocal_edges(&hugr) else { panic!() }; @@ -425,22 +424,22 @@ mod test { hugr.validate().unwrap(); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); let dfb = |bb: BasicBlockID| hugr.get_optype(bb.node()).as_dataflow_block().unwrap(); - // Entry node gets nonlocal1_type added, only + // Entry node gets ext_edge_type added, only assert_eq!( dfb(entry).inputs[..], - [nonlocal1_type.clone(), Type::UNIT, branch_type.clone()] + [ext_edge_type.clone(), Type::UNIT, branch_type.clone()] ); - // Left node gets both nonlocal1_type and nonlocal2_type + // Left node gets both ext_edge_type and dom_edge_type assert_eq!( dfb(bb_left).inputs[..], [ - nonlocal1_type.clone(), - nonlocal2_type, + ext_edge_type.clone(), + dom_edge_type, Type::UNIT, other_output_type ] ); - // Bottom node gets nonlocal1_type added, only - assert_eq!(dfb(bb_bottom).inputs[..], [nonlocal1_type, branch_type]); + // Bottom node gets ext_edge_type added, only + assert_eq!(dfb(bb_bottom).inputs[..], [ext_edge_type, branch_type]); } } From 6ae84515fac892c26fee252bed5782c6560b9547 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 May 2025 10:30:15 +0100 Subject: [PATCH 59/61] Extend localize_dfg test to cover multiple edges to same node --- hugr-passes/src/non_local.rs | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index a6b357622d..fe92b8cab6 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -143,6 +143,7 @@ mod test { type_row, types::Signature, }; + use rstest::rstest; use super::*; @@ -189,18 +190,25 @@ mod test { ); } - #[test] - fn localize_dfg() { + #[rstest] + fn localize_dfg(#[values(true, false)] same_src: bool) { let mut hugr = { - let mut outer = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); - let [w0] = outer.input_wires_arr(); - let [w1] = { + let mut outer = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap(); + let [w0, mut w1] = outer.input_wires_arr(); + if !same_src { + [w1] = outer + .add_dataflow_op(Noop::new(bool_t()), [w1]) + .unwrap() + .outputs_arr(); + } + let inner_outs = { let inner = outer - .dfg_builder(Signature::new_endo(bool_t()), [w0]) + .dfg_builder(Signature::new(vec![], vec![bool_t(); 2]), []) .unwrap(); - inner.finish_with_outputs([w0]).unwrap().outputs_arr() + // Note two `ext` edges to the same (Input) node here + inner.finish_with_outputs([w0, w1]).unwrap().outputs() }; - outer.finish_hugr_with_outputs([w1]).unwrap() + outer.finish_hugr_with_outputs(inner_outs).unwrap() }; assert!(ensure_no_nonlocal_edges(&hugr).is_err()); remove_nonlocal_edges(&mut hugr).unwrap(); From b336b744c2e8a7ca89c87f4b8e6a583fdc4cc06b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 May 2025 10:42:10 +0100 Subject: [PATCH 60/61] clippy --- hugr-passes/src/non_local.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index fe92b8cab6..ffdc39927d 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -42,9 +42,7 @@ pub fn nonlocal_edges(hugr: &H) -> impl Iterator Date: Wed, 28 May 2025 10:43:33 +0100 Subject: [PATCH 61/61] oops --- hugr-passes/src/non_local.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index ffdc39927d..75bbea399e 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -42,7 +42,7 @@ pub fn nonlocal_edges(hugr: &H) -> impl Iterator