diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 6f4e9c8ddf..44cbe2fb97 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -24,19 +24,37 @@ pub trait HugrInternals { where Self: 'p; + /// The portgraph graph structure returned by [`HugrInternals::region_portgraph`]. + type RegionPortgraph<'p>: LinkView + Clone + 'p + where + Self: 'p; + /// The type of nodes in the Hugr. type Node: Copy + Ord + std::fmt::Debug + std::fmt::Display + std::hash::Hash; + /// A mapping between HUGR nodes and portgraph nodes in the graph returned by + /// [`HugrInternals::region_portgraph`]. + type RegionPortgraphNodes: PortgraphNodeMap; + /// Returns a reference to the underlying portgraph. fn portgraph(&self) -> Self::Portgraph<'_>; - /// Returns a flat portgraph view of a region in the HUGR. - /// - /// This is a subgraph of [`HugrInternals::portgraph`], with a flat hierarchy. + /// Returns a flat portgraph view of a region in the HUGR, and a mapping between + /// HUGR nodes and portgraph nodes in the graph. + // + // NOTE: Ideally here we would just return `Self::RegionPortgraph<'_>`, but + // when doing so we are unable to restrict the type to implement petgraph's + // traits over references (e.g. `&MyGraph : IntoNodeIdentifiers`, which is + // needed if we want to use petgraph's algorithms on the region graph). + // This won't be solvable until we do the big petgraph refactor -.- + // In the meantime, just wrap the portgraph in a `FlatRegion` as needed. fn region_portgraph( &self, parent: Self::Node, - ) -> portgraph::view::FlatRegion<'_, impl LinkView + Clone + '_>; + ) -> ( + portgraph::view::FlatRegion<'_, Self::RegionPortgraph<'_>>, + Self::RegionPortgraphNodes, + ); /// Returns the portgraph [Hierarchy](portgraph::Hierarchy) of the graph /// returned by [`HugrInternals::portgraph`]. @@ -65,14 +83,56 @@ pub trait HugrInternals { fn base_hugr(&self) -> &Hugr; } +/// A map between hugr nodes and portgraph nodes in the graph returned by +/// [`HugrInternals::region_portgraph`]. +pub trait PortgraphNodeMap: Clone + Sized + std::fmt::Debug { + /// Returns the portgraph index of a HUGR node in the associated region + /// graph. + /// + /// If the node is not in the region, the result is undefined. + fn to_portgraph(&self, node: N) -> portgraph::NodeIndex; + + /// Returns the HUGR node for a portgraph node in the associated region + /// graph. + /// + /// If the node is not in the region, the result is undefined. + #[allow(clippy::wrong_self_convention)] + fn from_portgraph(&self, node: portgraph::NodeIndex) -> N; +} + +/// An identity map between HUGR nodes and portgraph nodes. +#[derive( + Copy, Clone, Debug, Default, Eq, PartialEq, Hash, PartialOrd, Ord, derive_more::Display, +)] +pub struct DefaultPGNodeMap; + +impl PortgraphNodeMap for DefaultPGNodeMap { + #[inline] + fn to_portgraph(&self, node: Node) -> portgraph::NodeIndex { + node.into_portgraph() + } + + #[inline] + fn from_portgraph(&self, node: portgraph::NodeIndex) -> Node { + node.into() + } +} + impl HugrInternals for Hugr { type Portgraph<'p> = &'p MultiPortGraph where Self: 'p; + type RegionPortgraph<'p> + = &'p MultiPortGraph + where + Self: 'p; + type Node = Node; + type RegionPortgraphNodes = DefaultPGNodeMap; + #[inline] fn portgraph(&self) -> Self::Portgraph<'_> { &self.graph @@ -82,10 +142,14 @@ impl HugrInternals for Hugr { fn region_portgraph( &self, parent: Self::Node, - ) -> portgraph::view::FlatRegion<'_, impl LinkView + Clone + '_> { + ) -> ( + portgraph::view::FlatRegion<'_, Self::RegionPortgraph<'_>>, + Self::RegionPortgraphNodes, + ) { let pg = self.portgraph(); let root = self.to_portgraph_node(parent); - portgraph::view::FlatRegion::new_without_root(pg, &self.hierarchy, root) + let region = portgraph::view::FlatRegion::new_without_root(pg, &self.hierarchy, root); + (region, DefaultPGNodeMap) } #[inline] diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 5935bc3039..0dc573a9d4 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -86,7 +86,7 @@ impl<'a> ValidationContext<'a> { /// The results of this computation should be cached in `self.dominators`. /// We don't do it here to avoid mutable borrows. fn compute_dominator(&self, parent: Node) -> Dominators { - let region = self.hugr.region_portgraph(parent); + let (region, _) = self.hugr.region_portgraph(parent); let entry_node = self.hugr.children(parent).next().unwrap(); dominators::simple_fast(®ion, entry_node.into_portgraph()) } @@ -357,7 +357,7 @@ impl<'a> ValidationContext<'a> { return Ok(()); }; - let region = self.hugr.region_portgraph(parent); + let (region, _) = self.hugr.region_portgraph(parent); let postorder = Topo::new(®ion); let nodes_visited = postorder .iter(®ion) diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs index 7b3b78f5e5..adf4c03e65 100644 --- a/hugr-core/src/hugr/views/impls.rs +++ b/hugr-core/src/hugr/views/impls.rs @@ -12,7 +12,7 @@ macro_rules! hugr_internal_methods { delegate::delegate! { to ({let $arg=self; $e}) { fn portgraph(&self) -> Self::Portgraph<'_>; - fn region_portgraph(&self, parent: Self::Node) -> portgraph::view::FlatRegion<'_, impl portgraph::view::LinkView + Clone + '_>; + fn region_portgraph(&self, parent: Self::Node) -> (portgraph::view::FlatRegion<'_, Self::RegionPortgraph<'_>>, Self::RegionPortgraphNodes); fn hierarchy(&self) -> &portgraph::Hierarchy; fn to_portgraph_node(&self, node: impl crate::ops::handle::NodeHandle) -> portgraph::NodeIndex; fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node; @@ -135,8 +135,16 @@ impl HugrInternals for &T { = T::Portgraph<'p> where Self: 'p; + + type RegionPortgraph<'p> + = T::RegionPortgraph<'p> + where + Self: 'p; + type Node = T::Node; + type RegionPortgraphNodes = T::RegionPortgraphNodes; + hugr_internal_methods! {this, *this} } impl HugrView for &T { @@ -149,8 +157,16 @@ impl HugrInternals for &mut T { = T::Portgraph<'p> where Self: 'p; + + type RegionPortgraph<'p> + = T::RegionPortgraph<'p> + where + Self: 'p; + type Node = T::Node; + type RegionPortgraphNodes = T::RegionPortgraphNodes; + hugr_internal_methods! {this, &**this} } impl HugrView for &mut T { @@ -169,8 +185,16 @@ impl HugrInternals for Rc { = T::Portgraph<'p> where Self: 'p; + + type RegionPortgraph<'p> + = T::RegionPortgraph<'p> + where + Self: 'p; + type Node = T::Node; + type RegionPortgraphNodes = T::RegionPortgraphNodes; + hugr_internal_methods! {this, this.as_ref()} } impl HugrView for Rc { @@ -183,8 +207,16 @@ impl HugrInternals for Arc { = T::Portgraph<'p> where Self: 'p; + + type RegionPortgraph<'p> + = T::RegionPortgraph<'p> + where + Self: 'p; + type Node = T::Node; + type RegionPortgraphNodes = T::RegionPortgraphNodes; + hugr_internal_methods! {this, this.as_ref()} } impl HugrView for Arc { @@ -197,8 +229,16 @@ impl HugrInternals for Box { = T::Portgraph<'p> where Self: 'p; + + type RegionPortgraph<'p> + = T::RegionPortgraph<'p> + where + Self: 'p; + type Node = T::Node; + type RegionPortgraphNodes = T::RegionPortgraphNodes; + hugr_internal_methods! {this, this.as_ref()} } impl HugrView for Box { @@ -217,8 +257,16 @@ impl HugrInternals for Cow<'_, T> { = T::Portgraph<'p> where Self: 'p; + + type RegionPortgraph<'p> + = T::RegionPortgraph<'p> + where + Self: 'p; + type Node = T::Node; + type RegionPortgraphNodes = T::RegionPortgraphNodes; + hugr_internal_methods! {this, this.as_ref()} } impl HugrView for Cow<'_, T> { diff --git a/hugr-core/src/hugr/views/rerooted.rs b/hugr-core/src/hugr/views/rerooted.rs index 43323b7439..9fde68f991 100644 --- a/hugr-core/src/hugr/views/rerooted.rs +++ b/hugr-core/src/hugr/views/rerooted.rs @@ -41,8 +41,16 @@ impl HugrInternals for Rerooted { = H::Portgraph<'p> where Self: 'p; + + type RegionPortgraph<'p> + = H::RegionPortgraph<'p> + where + Self: 'p; + type Node = H::Node; + type RegionPortgraphNodes = H::RegionPortgraphNodes; + super::impls::hugr_internal_methods! {this, &this.hugr} } diff --git a/hugr-llvm/src/emit/ops.rs b/hugr-llvm/src/emit/ops.rs index 76bb2bb09a..74c3ec0ae1 100644 --- a/hugr-llvm/src/emit/ops.rs +++ b/hugr-llvm/src/emit/ops.rs @@ -1,4 +1,5 @@ use anyhow::{anyhow, bail, Result}; +use hugr_core::hugr::internal::PortgraphNodeMap; use hugr_core::ops::{ constant::Sum, Call, CallIndirect, Case, Conditional, Const, ExtensionOp, Input, LoadConstant, LoadFunction, OpTag, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, @@ -70,10 +71,10 @@ where debug_assert!(i.out_value_types().count() == self.inputs.as_ref().unwrap().len()); debug_assert!(o.in_value_types().count() == self.outputs.as_ref().unwrap().len()); - let region_graph = node.hugr().region_portgraph(node.node()); + let (region_graph, node_map) = node.hugr().region_portgraph(node.node()); let topo = Topo::new(®ion_graph); for n in topo.iter(®ion_graph) { - let node = node.hugr().fat_optype(node.hugr().from_portgraph_node(n)); + let node = node.hugr().fat_optype(node_map.from_portgraph(n)); let inputs_rmb = context.node_ins_rmb(node)?; let inputs = inputs_rmb.read(context.builder(), [])?; let outputs = context.node_outs_rmb(node)?.promise(); diff --git a/hugr-passes/src/force_order.rs b/hugr-passes/src/force_order.rs index d7fe3135be..88a8f06a69 100644 --- a/hugr-passes/src/force_order.rs +++ b/hugr-passes/src/force_order.rs @@ -1,6 +1,7 @@ //! Provides [force_order], a tool for fixing the order of nodes in a Hugr. use std::{cmp::Reverse, collections::BinaryHeap, iter}; +use hugr_core::hugr::internal::PortgraphNodeMap; use hugr_core::{ hugr::{hugrmut::HugrMut, HugrError}, ops::{OpTag, OpTrait}, @@ -55,15 +56,15 @@ pub fn force_order_by_key, K: Ord>( // we filter out the input and output nodes from the topological sort let [i, o] = hugr.get_io(dp).unwrap(); let ordered_nodes = { - let i_pg = hugr.to_portgraph_node(i); - let o_pg = hugr.to_portgraph_node(o); - let rank = |n| rank(hugr, hugr.from_portgraph_node(n)); - let region = hugr.region_portgraph(dp); + let (region, node_map) = hugr.region_portgraph(dp); + let rank = |n| rank(hugr, node_map.from_portgraph(n)); + let i_pg = node_map.to_portgraph(i); + let o_pg = node_map.to_portgraph(o); let petgraph = NodeFiltered::from_fn(®ion, |x| x != i_pg && x != o_pg); ForceOrder::<_, portgraph::NodeIndex, _, _>::new(&petgraph, &rank) .iter(&petgraph) .filter_map(|x| { - let x = hugr.from_portgraph_node(x); + let x = node_map.from_portgraph(x); let expected_edge = Some(EdgeKind::StateOrder); let optype = hugr.get_optype(x); if optype.other_input() == expected_edge