Skip to content

Commit 0a37e8c

Browse files
committed
feat!: Return a node mapping in HugrInternals::region_portgraph
1 parent 851e396 commit 0a37e8c

File tree

6 files changed

+138
-16
lines changed

6 files changed

+138
-16
lines changed

hugr-core/src/hugr/internal.rs

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,37 @@ pub trait HugrInternals {
2424
where
2525
Self: 'p;
2626

27+
/// The portgraph graph structure returned by [`HugrInternals::region_portgraph`].
28+
type RegionPortgraph<'p>: LinkView<LinkEndpoint: Eq> + Clone + 'p
29+
where
30+
Self: 'p;
31+
2732
/// The type of nodes in the Hugr.
2833
type Node: Copy + Ord + std::fmt::Debug + std::fmt::Display + std::hash::Hash;
2934

35+
/// A mapping between HUGR nodes and portgraph nodes in the graph returned by
36+
/// [`HugrInternals::region_portgraph`].
37+
type RegionPortgraphNodes: PortgraphNodeMap<Self::Node>;
38+
3039
/// Returns a reference to the underlying portgraph.
3140
fn portgraph(&self) -> Self::Portgraph<'_>;
3241

33-
/// Returns a flat portgraph view of a region in the HUGR.
34-
///
35-
/// This is a subgraph of [`HugrInternals::portgraph`], with a flat hierarchy.
42+
/// Returns a flat portgraph view of a region in the HUGR, and a mapping between
43+
/// HUGR nodes and portgraph nodes in the graph.
44+
//
45+
// NOTE: Ideally here we would just return `Self::RegionPortgraph<'_>`, but
46+
// when doing so we are unable to restrict the type to implement petgraph's
47+
// traits over references (e.g. `&MyGraph : IntoNodeIdentifiers`, which is
48+
// needed if we want to use petgraph's algorithms on the region graph).
49+
// This won't be solvable until we don't do the big petgraph refactor -.-
50+
// In the meantime, just wrap the portgraph in a `FlatRegion` as needed.
3651
fn region_portgraph(
3752
&self,
3853
parent: Self::Node,
39-
) -> portgraph::view::FlatRegion<'_, impl LinkView<LinkEndpoint: Eq> + Clone + '_>;
54+
) -> (
55+
portgraph::view::FlatRegion<'_, Self::RegionPortgraph<'_>>,
56+
Self::RegionPortgraphNodes,
57+
);
4058

4159
/// Returns the portgraph [Hierarchy](portgraph::Hierarchy) of the graph
4260
/// returned by [`HugrInternals::portgraph`].
@@ -65,14 +83,56 @@ pub trait HugrInternals {
6583
fn base_hugr(&self) -> &Hugr;
6684
}
6785

86+
/// A map between hugr nodes and portgraph nodes in the graph returned by
87+
/// [`HugrInternals::region_portgraph`].
88+
pub trait PortgraphNodeMap<N>: Clone + Sized + std::fmt::Debug {
89+
/// Returns the portgraph index of a HUGR node in the associated region
90+
/// graph.
91+
///
92+
/// If the node is not in the region, the result is undefined.
93+
fn to_portgraph(&self, node: N) -> portgraph::NodeIndex;
94+
95+
/// Returns the HUGR node for a portgraph node in the associated region
96+
/// graph.
97+
///
98+
/// If the node is not in the region, the result is undefined.
99+
#[allow(clippy::wrong_self_convention)]
100+
fn from_portgraph(&self, node: portgraph::NodeIndex) -> N;
101+
}
102+
103+
/// An identity map between HUGR nodes and portgraph nodes.
104+
#[derive(
105+
Copy, Clone, Debug, Default, Eq, PartialEq, Hash, PartialOrd, Ord, derive_more::Display,
106+
)]
107+
pub struct DefaultPGNodeMap;
108+
109+
impl PortgraphNodeMap<Node> for DefaultPGNodeMap {
110+
#[inline]
111+
fn to_portgraph(&self, node: Node) -> portgraph::NodeIndex {
112+
node.into_portgraph()
113+
}
114+
115+
#[inline]
116+
fn from_portgraph(&self, node: portgraph::NodeIndex) -> Node {
117+
node.into()
118+
}
119+
}
120+
68121
impl HugrInternals for Hugr {
69122
type Portgraph<'p>
70123
= &'p MultiPortGraph
71124
where
72125
Self: 'p;
73126

127+
type RegionPortgraph<'p>
128+
= &'p MultiPortGraph
129+
where
130+
Self: 'p;
131+
74132
type Node = Node;
75133

134+
type RegionPortgraphNodes = DefaultPGNodeMap;
135+
76136
#[inline]
77137
fn portgraph(&self) -> Self::Portgraph<'_> {
78138
&self.graph
@@ -82,10 +142,14 @@ impl HugrInternals for Hugr {
82142
fn region_portgraph(
83143
&self,
84144
parent: Self::Node,
85-
) -> portgraph::view::FlatRegion<'_, impl LinkView<LinkEndpoint: Eq> + Clone + '_> {
145+
) -> (
146+
portgraph::view::FlatRegion<'_, Self::RegionPortgraph<'_>>,
147+
Self::RegionPortgraphNodes,
148+
) {
86149
let pg = self.portgraph();
87150
let root = self.to_portgraph_node(parent);
88-
portgraph::view::FlatRegion::new_without_root(pg, &self.hierarchy, root)
151+
let region = portgraph::view::FlatRegion::new_without_root(pg, &self.hierarchy, root);
152+
(region, DefaultPGNodeMap)
89153
}
90154

91155
#[inline]

hugr-core/src/hugr/validate.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ impl<'a> ValidationContext<'a> {
8686
/// The results of this computation should be cached in `self.dominators`.
8787
/// We don't do it here to avoid mutable borrows.
8888
fn compute_dominator(&self, parent: Node) -> Dominators<portgraph::NodeIndex> {
89-
let region = self.hugr.region_portgraph(parent);
89+
let (region, _) = self.hugr.region_portgraph(parent);
9090
let entry_node = self.hugr.children(parent).next().unwrap();
9191
dominators::simple_fast(&region, entry_node.into_portgraph())
9292
}
@@ -357,7 +357,7 @@ impl<'a> ValidationContext<'a> {
357357
return Ok(());
358358
};
359359

360-
let region = self.hugr.region_portgraph(parent);
360+
let (region, _) = self.hugr.region_portgraph(parent);
361361
let postorder = Topo::new(&region);
362362
let nodes_visited = postorder
363363
.iter(&region)

hugr-core/src/hugr/views/impls.rs

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ macro_rules! hugr_internal_methods {
1212
delegate::delegate! {
1313
to ({let $arg=self; $e}) {
1414
fn portgraph(&self) -> Self::Portgraph<'_>;
15-
fn region_portgraph(&self, parent: Self::Node) -> portgraph::view::FlatRegion<'_, impl portgraph::view::LinkView<LinkEndpoint: Eq> + Clone + '_>;
15+
fn region_portgraph(&self, parent: Self::Node) -> (portgraph::view::FlatRegion<'_, Self::RegionPortgraph<'_>>, Self::RegionPortgraphNodes);
1616
fn hierarchy(&self) -> &portgraph::Hierarchy;
1717
fn to_portgraph_node(&self, node: impl crate::ops::handle::NodeHandle<Self::Node>) -> portgraph::NodeIndex;
1818
fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node;
@@ -135,8 +135,16 @@ impl<T: HugrView> HugrInternals for &T {
135135
= T::Portgraph<'p>
136136
where
137137
Self: 'p;
138+
139+
type RegionPortgraph<'p>
140+
= T::RegionPortgraph<'p>
141+
where
142+
Self: 'p;
143+
138144
type Node = T::Node;
139145

146+
type RegionPortgraphNodes = T::RegionPortgraphNodes;
147+
140148
hugr_internal_methods! {this, *this}
141149
}
142150
impl<T: HugrView> HugrView for &T {
@@ -149,8 +157,16 @@ impl<T: HugrView> HugrInternals for &mut T {
149157
= T::Portgraph<'p>
150158
where
151159
Self: 'p;
160+
161+
type RegionPortgraph<'p>
162+
= T::RegionPortgraph<'p>
163+
where
164+
Self: 'p;
165+
152166
type Node = T::Node;
153167

168+
type RegionPortgraphNodes = T::RegionPortgraphNodes;
169+
154170
hugr_internal_methods! {this, &**this}
155171
}
156172
impl<T: HugrView> HugrView for &mut T {
@@ -169,8 +185,16 @@ impl<T: HugrView> HugrInternals for Rc<T> {
169185
= T::Portgraph<'p>
170186
where
171187
Self: 'p;
188+
189+
type RegionPortgraph<'p>
190+
= T::RegionPortgraph<'p>
191+
where
192+
Self: 'p;
193+
172194
type Node = T::Node;
173195

196+
type RegionPortgraphNodes = T::RegionPortgraphNodes;
197+
174198
hugr_internal_methods! {this, this.as_ref()}
175199
}
176200
impl<T: HugrView> HugrView for Rc<T> {
@@ -183,8 +207,16 @@ impl<T: HugrView> HugrInternals for Arc<T> {
183207
= T::Portgraph<'p>
184208
where
185209
Self: 'p;
210+
211+
type RegionPortgraph<'p>
212+
= T::RegionPortgraph<'p>
213+
where
214+
Self: 'p;
215+
186216
type Node = T::Node;
187217

218+
type RegionPortgraphNodes = T::RegionPortgraphNodes;
219+
188220
hugr_internal_methods! {this, this.as_ref()}
189221
}
190222
impl<T: HugrView> HugrView for Arc<T> {
@@ -197,8 +229,16 @@ impl<T: HugrView> HugrInternals for Box<T> {
197229
= T::Portgraph<'p>
198230
where
199231
Self: 'p;
232+
233+
type RegionPortgraph<'p>
234+
= T::RegionPortgraph<'p>
235+
where
236+
Self: 'p;
237+
200238
type Node = T::Node;
201239

240+
type RegionPortgraphNodes = T::RegionPortgraphNodes;
241+
202242
hugr_internal_methods! {this, this.as_ref()}
203243
}
204244
impl<T: HugrView> HugrView for Box<T> {
@@ -217,8 +257,16 @@ impl<T: HugrView + ToOwned> HugrInternals for Cow<'_, T> {
217257
= T::Portgraph<'p>
218258
where
219259
Self: 'p;
260+
261+
type RegionPortgraph<'p>
262+
= T::RegionPortgraph<'p>
263+
where
264+
Self: 'p;
265+
220266
type Node = T::Node;
221267

268+
type RegionPortgraphNodes = T::RegionPortgraphNodes;
269+
222270
hugr_internal_methods! {this, this.as_ref()}
223271
}
224272
impl<T: HugrView + ToOwned> HugrView for Cow<'_, T> {

hugr-core/src/hugr/views/rerooted.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,16 @@ impl<H: HugrView> HugrInternals for Rerooted<H> {
4141
= H::Portgraph<'p>
4242
where
4343
Self: 'p;
44+
45+
type RegionPortgraph<'p>
46+
= H::RegionPortgraph<'p>
47+
where
48+
Self: 'p;
49+
4450
type Node = H::Node;
4551

52+
type RegionPortgraphNodes = H::RegionPortgraphNodes;
53+
4654
super::impls::hugr_internal_methods! {this, &this.hugr}
4755
}
4856

hugr-llvm/src/emit/ops.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use anyhow::{anyhow, bail, Result};
2+
use hugr_core::hugr::internal::PortgraphNodeMap;
23
use hugr_core::ops::{
34
constant::Sum, Call, CallIndirect, Case, Conditional, Const, ExtensionOp, Input, LoadConstant,
45
LoadFunction, OpTag, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG,
@@ -70,10 +71,10 @@ where
7071
debug_assert!(i.out_value_types().count() == self.inputs.as_ref().unwrap().len());
7172
debug_assert!(o.in_value_types().count() == self.outputs.as_ref().unwrap().len());
7273

73-
let region_graph = node.hugr().region_portgraph(node.node());
74+
let (region_graph, node_map) = node.hugr().region_portgraph(node.node());
7475
let topo = Topo::new(&region_graph);
7576
for n in topo.iter(&region_graph) {
76-
let node = node.hugr().fat_optype(node.hugr().from_portgraph_node(n));
77+
let node = node.hugr().fat_optype(node_map.from_portgraph(n));
7778
let inputs_rmb = context.node_ins_rmb(node)?;
7879
let inputs = inputs_rmb.read(context.builder(), [])?;
7980
let outputs = context.node_outs_rmb(node)?.promise();

hugr-passes/src/force_order.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Provides [force_order], a tool for fixing the order of nodes in a Hugr.
22
use std::{cmp::Reverse, collections::BinaryHeap, iter};
33

4+
use hugr_core::hugr::internal::PortgraphNodeMap;
45
use hugr_core::{
56
hugr::{hugrmut::HugrMut, HugrError},
67
ops::{OpTag, OpTrait},
@@ -55,15 +56,15 @@ pub fn force_order_by_key<H: HugrMut<Node = Node>, K: Ord>(
5556
// we filter out the input and output nodes from the topological sort
5657
let [i, o] = hugr.get_io(dp).unwrap();
5758
let ordered_nodes = {
58-
let i_pg = hugr.to_portgraph_node(i);
59-
let o_pg = hugr.to_portgraph_node(o);
60-
let rank = |n| rank(hugr, hugr.from_portgraph_node(n));
61-
let region = hugr.region_portgraph(dp);
59+
let (region, node_map) = hugr.region_portgraph(dp);
60+
let rank = |n| rank(hugr, node_map.from_portgraph(n));
61+
let i_pg = node_map.to_portgraph(i);
62+
let o_pg = node_map.to_portgraph(o);
6263
let petgraph = NodeFiltered::from_fn(&region, |x| x != i_pg && x != o_pg);
6364
ForceOrder::<_, portgraph::NodeIndex, _, _>::new(&petgraph, &rank)
6465
.iter(&petgraph)
6566
.filter_map(|x| {
66-
let x = hugr.from_portgraph_node(x);
67+
let x = node_map.from_portgraph(x);
6768
let expected_edge = Some(EdgeKind::StateOrder);
6869
let optype = hugr.get_optype(x);
6970
if optype.other_input() == expected_edge

0 commit comments

Comments
 (0)