Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 180 additions & 54 deletions hugr-passes/src/replace_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@ use hugr_core::builder::{BuildError, BuildHandle, Dataflow};
use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef};
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::ops::constant::{OpaqueValue, Sum};
use hugr_core::ops::handle::DataflowOpID;
use hugr_core::ops::handle::{DataflowOpID, FuncID};
use hugr_core::ops::{
AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, ExtensionOp,
FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop,
Value, CFG, DFG,
};
use hugr_core::types::{
ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer,
ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeRow,
TypeTransformer,
};
use hugr_core::{Hugr, HugrView, Node, Wire};
use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Wire};

use crate::ComposablePass;

Expand All @@ -45,21 +46,37 @@ pub enum NodeTemplate {
/// Note this will be of limited use before [monomorphization](super::monomorphize())
/// because the new subtree will not be able to use type variables present in the
/// parent Hugr or previous op.
// TODO: store also a vec<TypeParam>, and update Hugr::validate to take &[TypeParam]s
// (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709
CompoundOp(Box<Hugr>),
// TODO allow also Call to a Node in the existing Hugr
// (can't see any other way to achieve multiple calls to the same decl.
// So client should add the functions before replacement, then remove unused ones afterwards.)
/// A Call to an existing function.
Call(Node, Vec<TypeArg>),
}

impl NodeTemplate {
/// Adds this instance to the specified [HugrMut] as a new node or subtree under a
/// given parent, returning the unique new child (of that parent) thus created
pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node {
///
/// # Panics
///
/// * If `parent` is not in the `hugr`
///
/// # Errors
///
/// * If `self` is a [Self::Call] and the target Node either
/// * is neither a [FuncDefn] nor a [FuncDecl]
/// * has a [`signature`] which the type-args of the [Self::Call] do not match
///
/// [`signature`]: hugr_core::types::PolyFuncType
pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Result<Node, BuildError> {
match self {
NodeTemplate::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type),
NodeTemplate::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root,
NodeTemplate::SingleOp(op_type) => Ok(hugr.add_node_with_parent(parent, op_type)),
NodeTemplate::CompoundOp(new_h) => Ok(hugr.insert_hugr(parent, *new_h).new_root),
NodeTemplate::Call(target, type_args) => {
let c = call(hugr, target, type_args)?;
let tgt_port = c.called_function_port();
let n = hugr.add_node_with_parent(parent, c);
hugr.connect(target, 0, n, tgt_port);
Ok(n)
}
}
}

Expand All @@ -72,10 +89,15 @@ impl NodeTemplate {
match self {
NodeTemplate::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs),
NodeTemplate::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs),
// Really we should check whether func points at a FuncDecl or FuncDefn and create
// the appropriate variety of FuncID but it doesn't matter for the purpose of making a Call.
NodeTemplate::Call(func, type_args) => {
dfb.call(&FuncID::<true>::from(func), &type_args, inputs)
}
}
}

fn replace(&self, hugr: &mut impl HugrMut, n: Node) {
fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> {
assert_eq!(hugr.children(n).count(), 0);
let new_optype = match self.clone() {
NodeTemplate::SingleOp(op_type) => op_type,
Expand All @@ -88,19 +110,57 @@ impl NodeTemplate {
}
root_opty
}
NodeTemplate::Call(func, type_args) => {
let c = call(hugr, func, type_args)?;
let static_inport = c.called_function_port();
// insert an input for the Call static input
hugr.insert_ports(n, Direction::Incoming, static_inport.index(), 1);
// connect the function to (what will be) the call
hugr.connect(func, 0, n, static_inport);
c.into()
}
};
*hugr.optype_mut(n) = new_optype;
Ok(())
}

fn signature(&self) -> Option<Cow<'_, Signature>> {
match self {
fn check_signature(
&self,
inputs: &TypeRow,
outputs: &TypeRow,
) -> Result<(), Option<Signature>> {
let sig = match self {
NodeTemplate::SingleOp(op_type) => op_type,
NodeTemplate::CompoundOp(hugr) => hugr.root_type(),
NodeTemplate::Call(_, _) => return Ok(()), // no way to tell
}
.dataflow_signature();
if sig.as_deref().map(Signature::io) == Some((inputs, outputs)) {
Ok(())
} else {
Err(sig.map(Cow::into_owned))
}
.dataflow_signature()
}
}

fn call<H: HugrView<Node = Node>>(
h: &H,
func: Node,
type_args: Vec<TypeArg>,
) -> Result<Call, BuildError> {
let func_sig = match h.get_optype(func) {
OpType::FuncDecl(fd) => fd.signature.clone(),
OpType::FuncDefn(fd) => fd.signature.clone(),
_ => {
return Err(BuildError::UnexpectedType {
node: func,
op_desc: "func defn/decl",
})
}
};
Ok(Call::try_new(func_sig, type_args)?)
}

/// A configuration of what types, ops, and constants should be replaced with what.
/// May be applied to a Hugr via [Self::run].
///
Expand Down Expand Up @@ -186,6 +246,8 @@ pub enum ReplaceTypesError {
ConstError(#[from] ConstTypeError),
#[error(transparent)]
LinearizeError(#[from] LinearizeError),
#[error("Replacement op for {0} could not be added because {1}")]
AddTemplateError(Node, BuildError),
}

impl ReplaceTypes {
Expand Down Expand Up @@ -370,8 +432,11 @@ impl ReplaceTypes {

OpType::Const(Const { value, .. }) => self.change_value(value),
OpType::ExtensionOp(ext_op) => Ok(
// Copy/discard insertion done by caller
if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) {
replacement.replace(hugr, n); // Copy/discard insertion done by caller
replacement
.replace(hugr, n)
.map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?;
true
} else {
let def = ext_op.def_arc();
Expand All @@ -382,7 +447,9 @@ impl ReplaceTypes {
.get(&def.as_ref().into())
.and_then(|rep_fn| rep_fn(&args))
{
replacement.replace(hugr, n);
replacement
.replace(hugr, n)
.map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?;
true
} else {
if ch {
Expand Down Expand Up @@ -515,24 +582,22 @@ mod test {
use std::sync::Arc;

use hugr_core::builder::{
inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder,
inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder,
};
use hugr_core::extension::prelude::{
bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder,
bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, PRELUDE_ID,
};
use hugr_core::extension::{simple_op::MakeExtensionOp, TypeDefBound, Version};
use hugr_core::extension::{simple_op::MakeExtensionOp, ExtensionSet, TypeDefBound, Version};
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::hugr::{IdentList, ValidationError};
use hugr_core::ops::{
constant::OpaqueValue, ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value,
};
use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef;
use hugr_core::ops::constant::OpaqueValue;
use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value};
use hugr_core::std_extensions::arithmetic::conversions::{self, ConvertOpDef};
use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES};
use hugr_core::std_extensions::collections::array::{
array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue,
};
use hugr_core::std_extensions::collections::list::{
list_type, list_type_def, ListOp, ListValue,
use hugr_core::std_extensions::collections::{
array::{self, array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue},
list::{list_type, list_type_def, ListOp, ListValue},
};
use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow};
use hugr_core::{type_row, Extension, HugrView};
Expand Down Expand Up @@ -601,30 +666,37 @@ mod test {
)
}

fn lowerer(ext: &Arc<Extension>) -> ReplaceTypes {
fn lowered_read(args: &[TypeArg]) -> Option<NodeTemplate> {
let ty = just_elem_type(args);
let mut dfb = DFGBuilder::new(inout_sig(
vec![array_type(64, ty.clone()), i64_t()],
ty.clone(),
))
fn lowered_read<T: Container + Dataflow>(
elem_ty: Type,
new: impl Fn(Signature) -> Result<T, BuildError>,
) -> T {
let mut dfb = new(Signature::new(
vec![array_type(64, elem_ty.clone()), i64_t()],
elem_ty.clone(),
)
.with_extension_delta(ExtensionSet::from_iter([
PRELUDE_ID,
array::EXTENSION_ID,
conversions::EXTENSION_ID,
])))
.unwrap();
let [val, idx] = dfb.input_wires_arr();
let [idx] = dfb
.add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx])
.unwrap()
.outputs_arr();
let [opt] = dfb
.add_dataflow_op(ArrayOpDef::get.to_concrete(elem_ty.clone(), 64), [val, idx])
.unwrap()
.outputs_arr();
let [res] = dfb
.build_unwrap_sum(1, option_type(Type::from(elem_ty)), opt)
.unwrap();
let [val, idx] = dfb.input_wires_arr();
let [idx] = dfb
.add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx])
.unwrap()
.outputs_arr();
let [opt] = dfb
.add_dataflow_op(ArrayOpDef::get.to_concrete(ty.clone(), 64), [val, idx])
.unwrap()
.outputs_arr();
let [res] = dfb
.build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt)
.unwrap();
Some(NodeTemplate::CompoundOp(Box::new(
dfb.finish_hugr_with_outputs([res]).unwrap(),
)))
}
dfb.set_outputs([res]).unwrap();
dfb
}

fn lowerer(ext: &Arc<Extension>) -> ReplaceTypes {
let pv = ext.get_type(PACKED_VEC).unwrap();
let mut lw = ReplaceTypes::default();
lw.replace_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t());
Expand All @@ -640,7 +712,13 @@ mod test {
.into(),
),
);
lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), Box::new(lowered_read));
lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args| {
Some(NodeTemplate::CompoundOp(Box::new(
lowered_read(just_elem_type(type_args).clone(), DFGBuilder::new)
.finish_hugr()
.unwrap(),
)))
});
lw
}

Expand Down Expand Up @@ -977,4 +1055,52 @@ mod test {
repl.run(&mut h).unwrap();
h.validate_no_extensions().unwrap();
}

#[test]
fn op_to_call() {
let e = ext();
let pv = e.get_type(PACKED_VEC).unwrap();
let inner = pv.instantiate([usize_t().into()]).unwrap();
let outer = pv
.instantiate([Type::new_extension(inner.clone()).into()])
.unwrap();
let mut dfb = DFGBuilder::new(inout_sig(vec![outer.into(), i64_t()], usize_t())).unwrap();
let [outer, idx] = dfb.input_wires_arr();
let [inner] = dfb
.add_dataflow_op(read_op(&e, inner.clone().into()), [outer, idx])
.unwrap()
.outputs_arr();
let res = dfb
.add_dataflow_op(read_op(&e, usize_t()), [inner, idx])
.unwrap();
let mut h = dfb.finish_hugr_with_outputs(res.outputs()).unwrap();
let read_func = h
.insert_hugr(
h.root(),
lowered_read(Type::new_var_use(0, TypeBound::Copyable), |sig| {
FunctionBuilder::new(
"lowered_read",
PolyFuncType::new([TypeBound::Copyable.into()], sig),
)
})
.finish_hugr()
.unwrap(),
)
.new_root;

let mut lw = lowerer(&e);
lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args| {
Some(NodeTemplate::Call(read_func, args.to_owned()))
});
lw.run(&mut h).unwrap();

assert_eq!(h.output_neighbours(read_func).count(), 2);
let ext_op_names = h
.nodes()
.filter_map(|n| h.get_optype(n).as_extension_op())
.map(|e| e.def().name())
.sorted()
.collect_vec();
assert_eq!(ext_op_names, ["get", "itousize", "panic",]);
}
}
4 changes: 2 additions & 2 deletions hugr-passes/src/replace_types/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub fn linearize_array(
let [to_discard] = dfb.input_wires_arr();
lin.copy_discard_op(ty, 0)?
.add(&mut dfb, [to_discard])
.unwrap();
.map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))?;
let ret = dfb.add_load_value(Value::unary_unit_sum());
dfb.finish_hugr_with_outputs([ret]).unwrap()
};
Expand Down Expand Up @@ -162,7 +162,7 @@ pub fn linearize_array(
let mut copies = lin
.copy_discard_op(ty, num_outports)?
.add(&mut dfb, [elem])
.unwrap()
.map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))?
.outputs();
let copy0 = copies.next().unwrap(); // We'll return this directly

Expand Down
Loading
Loading