Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 70 additions & 6 deletions hugr-core/src/hugr/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,37 @@ pub trait HugrInternals {
where
Self: 'p;

/// The portgraph graph structure returned by [`HugrInternals::region_portgraph`].
type RegionPortgraph<'p>: LinkView<LinkEndpoint: Eq> + Clone + 'p
where
Self: 'p;

/// The type of nodes in the Hugr.
type Node: Copy + Ord + std::fmt::Debug + std::fmt::Display + std::hash::Hash;

/// A mapping between HUGR nodes and portgraph nodes in the graph returned by
/// [`HugrInternals::region_portgraph`].
type RegionPortgraphNodes: PortgraphNodeMap<Self::Node>;

/// Returns a reference to the underlying portgraph.
fn portgraph(&self) -> Self::Portgraph<'_>;

/// Returns a flat portgraph view of a region in the HUGR.
///
/// This is a subgraph of [`HugrInternals::portgraph`], with a flat hierarchy.
/// Returns a flat portgraph view of a region in the HUGR, and a mapping between
/// HUGR nodes and portgraph nodes in the graph.
//
// NOTE: Ideally here we would just return `Self::RegionPortgraph<'_>`, but
// when doing so we are unable to restrict the type to implement petgraph's
// traits over references (e.g. `&MyGraph : IntoNodeIdentifiers`, which is
// needed if we want to use petgraph's algorithms on the region graph).
// This won't be solvable until we do the big petgraph refactor -.-
// In the meantime, just wrap the portgraph in a `FlatRegion` as needed.
fn region_portgraph(
&self,
parent: Self::Node,
) -> portgraph::view::FlatRegion<'_, impl LinkView<LinkEndpoint: Eq> + Clone + '_>;
) -> (
portgraph::view::FlatRegion<'_, Self::RegionPortgraph<'_>>,
Self::RegionPortgraphNodes,
);

/// Returns the portgraph [Hierarchy](portgraph::Hierarchy) of the graph
/// returned by [`HugrInternals::portgraph`].
Expand Down Expand Up @@ -65,14 +83,56 @@ pub trait HugrInternals {
fn base_hugr(&self) -> &Hugr;
}

/// A map between hugr nodes and portgraph nodes in the graph returned by
/// [`HugrInternals::region_portgraph`].
pub trait PortgraphNodeMap<N>: Clone + Sized + std::fmt::Debug {
/// Returns the portgraph index of a HUGR node in the associated region
/// graph.
///
/// If the node is not in the region, the result is undefined.
fn to_portgraph(&self, node: N) -> portgraph::NodeIndex;

/// Returns the HUGR node for a portgraph node in the associated region
/// graph.
///
/// If the node is not in the region, the result is undefined.
#[allow(clippy::wrong_self_convention)]
fn from_portgraph(&self, node: portgraph::NodeIndex) -> N;
}

/// An identity map between HUGR nodes and portgraph nodes.
#[derive(
Copy, Clone, Debug, Default, Eq, PartialEq, Hash, PartialOrd, Ord, derive_more::Display,
)]
pub struct DefaultPGNodeMap;

impl PortgraphNodeMap<Node> for DefaultPGNodeMap {
#[inline]
fn to_portgraph(&self, node: Node) -> portgraph::NodeIndex {
node.into_portgraph()
}

#[inline]
fn from_portgraph(&self, node: portgraph::NodeIndex) -> Node {
node.into()
}
}

impl HugrInternals for Hugr {
type Portgraph<'p>
= &'p MultiPortGraph
where
Self: 'p;

type RegionPortgraph<'p>
= &'p MultiPortGraph
where
Self: 'p;

type Node = Node;

type RegionPortgraphNodes = DefaultPGNodeMap;

#[inline]
fn portgraph(&self) -> Self::Portgraph<'_> {
&self.graph
Expand All @@ -82,10 +142,14 @@ impl HugrInternals for Hugr {
fn region_portgraph(
&self,
parent: Self::Node,
) -> portgraph::view::FlatRegion<'_, impl LinkView<LinkEndpoint: Eq> + Clone + '_> {
) -> (
portgraph::view::FlatRegion<'_, Self::RegionPortgraph<'_>>,
Self::RegionPortgraphNodes,
) {
let pg = self.portgraph();
let root = self.to_portgraph_node(parent);
portgraph::view::FlatRegion::new_without_root(pg, &self.hierarchy, root)
let region = portgraph::view::FlatRegion::new_without_root(pg, &self.hierarchy, root);
(region, DefaultPGNodeMap)
}

#[inline]
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl<'a> ValidationContext<'a> {
/// The results of this computation should be cached in `self.dominators`.
/// We don't do it here to avoid mutable borrows.
fn compute_dominator(&self, parent: Node) -> Dominators<portgraph::NodeIndex> {
let region = self.hugr.region_portgraph(parent);
let (region, _) = self.hugr.region_portgraph(parent);
let entry_node = self.hugr.children(parent).next().unwrap();
dominators::simple_fast(&region, entry_node.into_portgraph())
}
Expand Down Expand Up @@ -357,7 +357,7 @@ impl<'a> ValidationContext<'a> {
return Ok(());
};

let region = self.hugr.region_portgraph(parent);
let (region, _) = self.hugr.region_portgraph(parent);
let postorder = Topo::new(&region);
let nodes_visited = postorder
.iter(&region)
Expand Down
50 changes: 49 additions & 1 deletion hugr-core/src/hugr/views/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ macro_rules! hugr_internal_methods {
delegate::delegate! {
to ({let $arg=self; $e}) {
fn portgraph(&self) -> Self::Portgraph<'_>;
fn region_portgraph(&self, parent: Self::Node) -> portgraph::view::FlatRegion<'_, impl portgraph::view::LinkView<LinkEndpoint: Eq> + Clone + '_>;
fn region_portgraph(&self, parent: Self::Node) -> (portgraph::view::FlatRegion<'_, Self::RegionPortgraph<'_>>, Self::RegionPortgraphNodes);
fn hierarchy(&self) -> &portgraph::Hierarchy;
fn to_portgraph_node(&self, node: impl crate::ops::handle::NodeHandle<Self::Node>) -> portgraph::NodeIndex;
fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node;
Expand Down Expand Up @@ -135,8 +135,16 @@ impl<T: HugrView> HugrInternals for &T {
= T::Portgraph<'p>
where
Self: 'p;

type RegionPortgraph<'p>
= T::RegionPortgraph<'p>
where
Self: 'p;

type Node = T::Node;

type RegionPortgraphNodes = T::RegionPortgraphNodes;

hugr_internal_methods! {this, *this}
}
impl<T: HugrView> HugrView for &T {
Expand All @@ -149,8 +157,16 @@ impl<T: HugrView> HugrInternals for &mut T {
= T::Portgraph<'p>
where
Self: 'p;

type RegionPortgraph<'p>
= T::RegionPortgraph<'p>
where
Self: 'p;

type Node = T::Node;

type RegionPortgraphNodes = T::RegionPortgraphNodes;

hugr_internal_methods! {this, &**this}
}
impl<T: HugrView> HugrView for &mut T {
Expand All @@ -169,8 +185,16 @@ impl<T: HugrView> HugrInternals for Rc<T> {
= T::Portgraph<'p>
where
Self: 'p;

type RegionPortgraph<'p>
= T::RegionPortgraph<'p>
where
Self: 'p;

type Node = T::Node;

type RegionPortgraphNodes = T::RegionPortgraphNodes;

hugr_internal_methods! {this, this.as_ref()}
}
impl<T: HugrView> HugrView for Rc<T> {
Expand All @@ -183,8 +207,16 @@ impl<T: HugrView> HugrInternals for Arc<T> {
= T::Portgraph<'p>
where
Self: 'p;

type RegionPortgraph<'p>
= T::RegionPortgraph<'p>
where
Self: 'p;

type Node = T::Node;

type RegionPortgraphNodes = T::RegionPortgraphNodes;

hugr_internal_methods! {this, this.as_ref()}
}
impl<T: HugrView> HugrView for Arc<T> {
Expand All @@ -197,8 +229,16 @@ impl<T: HugrView> HugrInternals for Box<T> {
= T::Portgraph<'p>
where
Self: 'p;

type RegionPortgraph<'p>
= T::RegionPortgraph<'p>
where
Self: 'p;

type Node = T::Node;

type RegionPortgraphNodes = T::RegionPortgraphNodes;

hugr_internal_methods! {this, this.as_ref()}
}
impl<T: HugrView> HugrView for Box<T> {
Expand All @@ -217,8 +257,16 @@ impl<T: HugrView + ToOwned> HugrInternals for Cow<'_, T> {
= T::Portgraph<'p>
where
Self: 'p;

type RegionPortgraph<'p>
= T::RegionPortgraph<'p>
where
Self: 'p;

type Node = T::Node;

type RegionPortgraphNodes = T::RegionPortgraphNodes;

hugr_internal_methods! {this, this.as_ref()}
}
impl<T: HugrView + ToOwned> HugrView for Cow<'_, T> {
Expand Down
8 changes: 8 additions & 0 deletions hugr-core/src/hugr/views/rerooted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,16 @@ impl<H: HugrView> HugrInternals for Rerooted<H> {
= H::Portgraph<'p>
where
Self: 'p;

type RegionPortgraph<'p>
= H::RegionPortgraph<'p>
where
Self: 'p;

type Node = H::Node;

type RegionPortgraphNodes = H::RegionPortgraphNodes;

super::impls::hugr_internal_methods! {this, &this.hugr}
}

Expand Down
5 changes: 3 additions & 2 deletions hugr-llvm/src/emit/ops.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use anyhow::{anyhow, bail, Result};
use hugr_core::hugr::internal::PortgraphNodeMap;
use hugr_core::ops::{
constant::Sum, Call, CallIndirect, Case, Conditional, Const, ExtensionOp, Input, LoadConstant,
LoadFunction, OpTag, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG,
Expand Down Expand Up @@ -70,10 +71,10 @@ where
debug_assert!(i.out_value_types().count() == self.inputs.as_ref().unwrap().len());
debug_assert!(o.in_value_types().count() == self.outputs.as_ref().unwrap().len());

let region_graph = node.hugr().region_portgraph(node.node());
let (region_graph, node_map) = node.hugr().region_portgraph(node.node());
let topo = Topo::new(&region_graph);
for n in topo.iter(&region_graph) {
let node = node.hugr().fat_optype(node.hugr().from_portgraph_node(n));
let node = node.hugr().fat_optype(node_map.from_portgraph(n));
let inputs_rmb = context.node_ins_rmb(node)?;
let inputs = inputs_rmb.read(context.builder(), [])?;
let outputs = context.node_outs_rmb(node)?.promise();
Expand Down
11 changes: 6 additions & 5 deletions hugr-passes/src/force_order.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Provides [force_order], a tool for fixing the order of nodes in a Hugr.
use std::{cmp::Reverse, collections::BinaryHeap, iter};

use hugr_core::hugr::internal::PortgraphNodeMap;
use hugr_core::{
hugr::{hugrmut::HugrMut, HugrError},
ops::{OpTag, OpTrait},
Expand Down Expand Up @@ -55,15 +56,15 @@ pub fn force_order_by_key<H: HugrMut<Node = Node>, K: Ord>(
// we filter out the input and output nodes from the topological sort
let [i, o] = hugr.get_io(dp).unwrap();
let ordered_nodes = {
let i_pg = hugr.to_portgraph_node(i);
let o_pg = hugr.to_portgraph_node(o);
let rank = |n| rank(hugr, hugr.from_portgraph_node(n));
let region = hugr.region_portgraph(dp);
let (region, node_map) = hugr.region_portgraph(dp);
let rank = |n| rank(hugr, node_map.from_portgraph(n));
let i_pg = node_map.to_portgraph(i);
let o_pg = node_map.to_portgraph(o);
let petgraph = NodeFiltered::from_fn(&region, |x| x != i_pg && x != o_pg);
ForceOrder::<_, portgraph::NodeIndex, _, _>::new(&petgraph, &rank)
.iter(&petgraph)
.filter_map(|x| {
let x = hugr.from_portgraph_node(x);
let x = node_map.from_portgraph(x);
let expected_edge = Some(EdgeKind::StateOrder);
let optype = hugr.get_optype(x);
if optype.other_input() == expected_edge
Expand Down
Loading