Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,11 @@
from . import op as op, wire as wire, noise as noise, qubit as qubit
from .groups import wired as wired, kernel as kernel

try:
# NOTE: make sure optional cirq dependency is installed
import cirq as cirq_package # noqa: F401
except ImportError:
pass

Check warning on line 8 in src/bloqade/squin/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/__init__.py#L7-L8

Added lines #L7 - L8 were not covered by tests
else:
from . import cirq as cirq
from .cirq import load_circuit as load_circuit
52 changes: 52 additions & 0 deletions src/bloqade/squin/cirq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Any

import cirq
from kirin import ir, types
from kirin.dialects import func

from . import lowering as lowering
from .. import kernel
from .lowering import Squin


def load_circuit(
circuit: cirq.Circuit,
kernel_name: str = "main",
dialects: ir.DialectGroup = kernel,
globals: dict[str, Any] | None = None,
file: str | None = None,
lineno_offset: int = 0,
col_offset: int = 0,
compactify: bool = True,
):

target = Squin(dialects=dialects, circuit=circuit)
body = target.run(
circuit,
source=str(circuit), # TODO: proper source string
file=file,
globals=globals,
lineno_offset=lineno_offset,
col_offset=col_offset,
compactify=compactify,
)

# NOTE: no return value
return_value = func.ConstantNone()
body.blocks[0].stmts.append(return_value)
body.blocks[0].stmts.append(func.Return(value_or_stmt=return_value))

code = func.Function(
sym_name=kernel_name,
signature=func.Signature((), types.NoneType),
body=body,
)

return ir.Method(
mod=None,
py_func=None,
sym_name=kernel_name,
arg_names=[],
dialects=dialects,
code=code,
)
287 changes: 287 additions & 0 deletions src/bloqade/squin/cirq/lowering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
import math
from typing import Any
from dataclasses import field, dataclass

import cirq
from kirin import ir, lowering
from kirin.rewrite import Walk, CFGCompactify
from kirin.dialects import py, ilist

from .. import op, noise, qubit

CirqNode = cirq.Circuit | cirq.Moment | cirq.Gate | cirq.Qid | cirq.Operation

DecomposeNode = (
cirq.SwapPowGate
| cirq.ISwapPowGate
| cirq.PhasedXPowGate
| cirq.PhasedXZGate
| cirq.CSwapGate
)


@dataclass
class Squin(lowering.LoweringABC[CirqNode]):
"""Lower a cirq.Circuit object to a squin kernel"""

circuit: cirq.Circuit
qreg: qubit.New = field(init=False)
qreg_index: dict[cirq.Qid, int] = field(init=False, default_factory=dict)
next_qreg_index: int = field(init=False, default=0)

def lower_qubit_getindex(self, state: lowering.State[CirqNode], qid: cirq.Qid):
index = self.qreg_index.get(qid)

if index is None:
index = self.next_qreg_index
self.qreg_index[qid] = index
self.next_qreg_index += 1

index_ssa = state.current_frame.push(py.Constant(index)).result
qbit_getitem = state.current_frame.push(py.GetItem(self.qreg.result, index_ssa))
return qbit_getitem.result

def lower_qubit_getindices(
self, state: lowering.State[CirqNode], qids: list[cirq.Qid]
):
qbits_getitem = [self.lower_qubit_getindex(state, qid) for qid in qids]
qbits_stmt = ilist.New(values=qbits_getitem)
qbits_result = state.current_frame.get(qbits_stmt.name)

if qbits_result is not None:
return qbits_result

Check warning on line 52 in src/bloqade/squin/cirq/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/lowering.py#L52

Added line #L52 was not covered by tests

state.current_frame.push(qbits_stmt)
return qbits_stmt.result

def run(
self,
stmt: CirqNode,
*,
source: str | None = None,
globals: dict[str, Any] | None = None,
file: str | None = None,
lineno_offset: int = 0,
col_offset: int = 0,
compactify: bool = True,
) -> ir.Region:

state = lowering.State(
self,
file=file,
lineno_offset=lineno_offset,
col_offset=col_offset,
)

with state.frame(
[stmt],
globals=globals,
finalize_next=False,
) as frame:
# NOTE: create a global register of qubits first
# TODO: can there be a circuit without qubits?
n_qubits = cirq.num_qubits(self.circuit)
n = frame.push(py.Constant(n_qubits))
self.qreg = frame.push(qubit.New(n_qubits=n.result))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might it be better to have the register be an input to the kernel instead of inside of the kernel? Otherwise, the kernel generated cannot be used as a subroutine elsewhere. Likewise, returning the qubit register. See QuEraComputing/bloqade#249.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jon-wurtz see my comment in the RFC issue (let's leave the discussion in on place). Ideally we support both, loading a "standalone" circuit as well as passing the register as an argument, controlled by a kwarg.


self.visit(state, stmt)

if compactify:
Walk(CFGCompactify()).rewrite(frame.curr_region)

region = frame.curr_region

return region

def visit(self, state: lowering.State[CirqNode], node: CirqNode) -> lowering.Result:
name = node.__class__.__name__
return getattr(self, f"visit_{name}", self.generic_visit)(state, node)

def generic_visit(self, state: lowering.State[CirqNode], node: CirqNode):
if isinstance(node, CirqNode):
raise lowering.BuildError(

Check warning on line 102 in src/bloqade/squin/cirq/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/lowering.py#L101-L102

Added lines #L101 - L102 were not covered by tests
f"Cannot lower {node.__class__.__name__} node: {node}"
)
raise lowering.BuildError(

Check warning on line 105 in src/bloqade/squin/cirq/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/lowering.py#L105

Added line #L105 was not covered by tests
f"Unexpected `{node.__class__.__name__}` node: {repr(node)} is not an AST node"
)

def lower_literal(self, state: lowering.State[CirqNode], value) -> ir.SSAValue:
raise lowering.BuildError("Literals not supported in cirq circuit")

Check warning on line 110 in src/bloqade/squin/cirq/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/lowering.py#L110

Added line #L110 was not covered by tests

def lower_global(
self, state: lowering.State[CirqNode], node: CirqNode
) -> lowering.LoweringABC.Result:
raise lowering.BuildError("Literals not supported in cirq circuit")

Check warning on line 115 in src/bloqade/squin/cirq/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/lowering.py#L115

Added line #L115 was not covered by tests

def visit_Circuit(
self, state: lowering.State[CirqNode], node: cirq.Circuit
) -> lowering.Result:
for moment in node:
state.lower(moment)

def visit_Moment(
self, state: lowering.State[CirqNode], node: cirq.Moment
) -> lowering.Result:
for op_ in node.operations:
state.lower(op_)
Comment on lines +123 to +127
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe in a separate PR, to address @jon-wurtz 's comment, we should:

  • lower moment with same gate into broadcast(op, ...)
  • lower moment with different gate into parallel statement (which will be added on Kirin side)

to facilitate this we may want to just add a Moment statement as an IR to temperarily perserve this highlevel structure coming from cirq and then implement a rewrite from Moment to squin.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or a Moment should be rewrite into something like a Parallel with a region?


def visit_GateOperation(
self, state: lowering.State[CirqNode], node: cirq.GateOperation
):
if isinstance(node.gate, cirq.MeasurementGate):
# NOTE: special dispatch here, since measurement is a gate + a qubit in cirq,
# but a single statement in squin
return self.lower_measurement(state, node)

if isinstance(node.gate, DecomposeNode):
# NOTE: easier to decompose these, but for that we need the qubits too,
# so we need to do this within this method
for subnode in cirq.decompose_once(node):
state.lower(subnode)
return

op_ = state.lower(node.gate).expect_one()
qbits = self.lower_qubit_getindices(state, node.qubits)
return state.current_frame.push(qubit.Apply(operator=op_, qubits=qbits))

def lower_measurement(
self, state: lowering.State[CirqNode], node: cirq.GateOperation
):
if len(node.qubits) == 1:
qbit = self.lower_qubit_getindex(state, node.qubits[0])
return state.current_frame.push(qubit.MeasureQubit(qbit))

qbits = self.lower_qubit_getindices(state, node.qubits)
return state.current_frame.push(qubit.MeasureQubitList(qbits))

def visit_SingleQubitPauliStringGateOperation(
self,
state: lowering.State[CirqNode],
node: cirq.SingleQubitPauliStringGateOperation,
):

match node.pauli:
case cirq.X:
op_ = op.stmts.X()
case cirq.Y:
op_ = op.stmts.Y()
case cirq.Z:
op_ = op.stmts.Z()
case cirq.I:
op_ = op.stmts.Identity(sites=1)
case _:
raise lowering.BuildError(f"Unexpected Pauli operation {node.pauli}")

Check warning on line 174 in src/bloqade/squin/cirq/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/lowering.py#L171-L174

Added lines #L171 - L174 were not covered by tests

state.current_frame.push(op_)
qargs = self.lower_qubit_getindices(state, [node.qubit])
return state.current_frame.push(qubit.Apply(op_.result, qargs))

def visit_HPowGate(self, state: lowering.State[CirqNode], node: cirq.HPowGate):
if node.exponent == 1:
return state.current_frame.push(op.stmts.H())

return state.lower(node.in_su2())

Check warning on line 184 in src/bloqade/squin/cirq/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/lowering.py#L184

Added line #L184 was not covered by tests

def visit_XPowGate(self, state: lowering.State[CirqNode], node: cirq.XPowGate):
if node.exponent == 1:
return state.current_frame.push(op.stmts.X())

return self.visit(state, node.in_su2())

def visit_YPowGate(self, state: lowering.State[CirqNode], node: cirq.YPowGate):
if node.exponent == 1:
return state.current_frame.push(op.stmts.Y())

return self.visit(state, node.in_su2())

def visit_ZPowGate(self, state: lowering.State[CirqNode], node: cirq.ZPowGate):
if node.exponent == 0.5:
return state.current_frame.push(op.stmts.S())

if node.exponent == 0.25:
return state.current_frame.push(op.stmts.T())

if node.exponent == 1:
return state.current_frame.push(op.stmts.Z())

# NOTE: just for the Z gate, an arbitrary exponent is equivalent to the ShiftOp
t = node.exponent
theta = state.current_frame.push(py.Constant(math.pi * t))
return state.current_frame.push(op.stmts.ShiftOp(theta=theta.result))

def visit_Rx(self, state: lowering.State[CirqNode], node: cirq.Rx):
x = state.current_frame.push(op.stmts.X())
angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent))
return state.current_frame.push(op.stmts.Rot(axis=x.result, angle=angle.result))

def visit_Ry(self, state: lowering.State[CirqNode], node: cirq.Ry):
y = state.current_frame.push(op.stmts.Y())
angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent))
return state.current_frame.push(op.stmts.Rot(axis=y.result, angle=angle.result))

def visit_Rz(self, state: lowering.State[CirqNode], node: cirq.Rz):
z = state.current_frame.push(op.stmts.Z())
angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent))
return state.current_frame.push(op.stmts.Rot(axis=z.result, angle=angle.result))

def visit_CXPowGate(self, state: lowering.State[CirqNode], node: cirq.CXPowGate):
x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one()
return state.current_frame.push(op.stmts.Control(x, n_controls=1))

def visit_CZPowGate(self, state: lowering.State[CirqNode], node: cirq.CZPowGate):
z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one()
return state.current_frame.push(op.stmts.Control(z, n_controls=1))

def visit_ControlledOperation(
self, state: lowering.State[CirqNode], node: cirq.ControlledOperation
):
return self.visit_GateOperation(state, node)

def visit_ControlledGate(
self, state: lowering.State[CirqNode], node: cirq.ControlledGate
):
op_ = state.lower(node.sub_gate).expect_one()
n_controls = node.num_controls()
return state.current_frame.push(op.stmts.Control(op_, n_controls=n_controls))

def visit_XXPowGate(self, state: lowering.State[CirqNode], node: cirq.XXPowGate):
x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one()
return state.current_frame.push(op.stmts.Kron(x, x))

def visit_YYPowGate(self, state: lowering.State[CirqNode], node: cirq.YYPowGate):
y = state.lower(cirq.YPowGate(exponent=node.exponent)).expect_one()
return state.current_frame.push(op.stmts.Kron(y, y))

def visit_ZZPowGate(self, state: lowering.State[CirqNode], node: cirq.ZZPowGate):
z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one()
return state.current_frame.push(op.stmts.Kron(z, z))

def visit_CCXPowGate(self, state: lowering.State[CirqNode], node: cirq.CCXPowGate):
x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one()
return state.current_frame.push(op.stmts.Control(x, n_controls=2))

def visit_CCZPowGate(self, state: lowering.State[CirqNode], node: cirq.CCZPowGate):
z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one()
return state.current_frame.push(op.stmts.Control(z, n_controls=2))

def visit_BitFlipChannel(
self, state: lowering.State[CirqNode], node: cirq.BitFlipChannel
):
x = state.current_frame.push(op.stmts.X())
p = state.current_frame.push(py.Constant(node.p))
return state.current_frame.push(

Check warning on line 273 in src/bloqade/squin/cirq/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/lowering.py#L271-L273

Added lines #L271 - L273 were not covered by tests
noise.stmts.PauliError(basis=x.result, p=p.result)
)

def visit_AmplitudeDampingChannel(
self, state: lowering.State[CirqNode], node: cirq.AmplitudeDampingChannel
):
r = state.current_frame.push(op.stmts.Reset())
p = state.current_frame.push(py.Constant(node.gamma))

Check warning on line 281 in src/bloqade/squin/cirq/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/lowering.py#L280-L281

Added lines #L280 - L281 were not covered by tests

# TODO: do we need a dedicated noise stmt for this? Using PauliError
# with this basis feels like a hack
noise_channel = noise.stmts.PauliError(basis=r.result, p=p.result)

Check warning on line 285 in src/bloqade/squin/cirq/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/lowering.py#L285

Added line #L285 was not covered by tests

return noise_channel

Check warning on line 287 in src/bloqade/squin/cirq/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/lowering.py#L287

Added line #L287 was not covered by tests
4 changes: 2 additions & 2 deletions src/bloqade/squin/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from kirin.dialects import ilist
from kirin.rewrite.walk import Walk

from . import op, wire, qubit
from . import op, wire, noise, qubit
from .op.rewrite import PyMultToSquinMult
from .rewrite.measure_desugar import MeasureDesugarRule


@ir.dialect_group(structural_no_opt.union([op, qubit]))
@ir.dialect_group(structural_no_opt.union([op, qubit, noise]))
def kernel(self):
fold_pass = passes.Fold(self)
typeinfer_pass = passes.TypeInfer(self)
Expand Down
Loading