Skip to content

Commit

Permalink
Decompose pass (Qiskit#1487)
Browse files Browse the repository at this point in the history
* clean unroller

* docstring

* moving conditianl handeling in side substitute_circuit_one

* rename the method to make it more descriptive

* initial commit

* linting
  • Loading branch information
1ucian0 authored and ajavadia committed Dec 14, 2018
1 parent fccaf18 commit 6a24148
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 59 deletions.
53 changes: 36 additions & 17 deletions qiskit/dagcircuit/_dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ def substitute_circuit_all(self, op, input_circuit, wires=None):
Raises:
DAGCircuitError: if met with unexpected predecessor/successors
"""
# TODO: rewrite this method to call substitute_circuit_one
# TODO: rewrite this method to call substitute_node_with_dag
wires = wires or []
if op.name not in self.basis:
raise DAGCircuitError("%s is not in the list of basis operations"
Expand Down Expand Up @@ -1095,37 +1095,58 @@ def substitute_circuit_all(self, op, input_circuit, wires=None):
self.multi_graph.remove_edge(
p[0], self.output_map[w])

def substitute_circuit_one(self, node, input_circuit, wires=None):
"""Replace one node with input_circuit.
def substitute_node_with_dag(self, node, input_dag, wires=None):
"""Replace one node with dag.
Args:
node (int): node of self.multi_graph (of type "op") to substitute
input_circuit (DAGCircuit): circuit that will substitute the node
input_dag (DAGCircuit): circuit that will substitute the node
wires (list[(Register, index)]): gives an order for (qu)bits
in the input circuit. This order gets matched to the node wires
by qargs first, then cargs, then conditions.
Raises:
DAGCircuitError: if met with unexpected predecessor/successors
"""
wires = wires or []
nd = self.multi_graph.node[node]

self._check_wires_list(wires, nd["op"], input_circuit, nd["condition"])
union_basis = self._make_union_basis(input_circuit)
union_gates = self._make_union_gates(input_circuit)
condition = nd["condition"]
# the decomposition rule must be amended if used in a
# conditional context. delete the op nodes and replay
# them with the condition.
if condition:
input_dag.add_creg(condition[0])
to_replay = []
for n_it in nx.topological_sort(input_dag.multi_graph):
n = input_dag.multi_graph.nodes[n_it]
if n["type"] == "op":
n["op"].control = condition
to_replay.append(n)
for n in input_dag.get_op_nodes():
input_dag._remove_op_node(n)
for n in to_replay:
input_dag.apply_operation_back(n["op"], condition=condition)

if wires is None:
qwires = [w for w in input_dag.wires if isinstance(w[0], QuantumRegister)]
cwires = [w for w in input_dag.wires if isinstance(w[0], ClassicalRegister)]
wires = qwires + cwires

self._check_wires_list(wires, nd["op"], input_dag, nd["condition"])
union_basis = self._make_union_basis(input_dag)
union_gates = self._make_union_gates(input_dag)

# Create a proxy wire_map to identify fragments and duplicates
# and determine what registers need to be added to self
proxy_map = {w: QuantumRegister(1, 'proxy') for w in wires}
add_qregs = self._check_edgemap_registers(proxy_map,
input_circuit.qregs,
input_dag.qregs,
{}, False)
for qreg in add_qregs:
self.add_qreg(qreg)

add_cregs = self._check_edgemap_registers(proxy_map,
input_circuit.cregs,
input_dag.cregs,
{}, False)
for creg in add_cregs:
self.add_creg(creg)
Expand All @@ -1148,17 +1169,15 @@ def substitute_circuit_one(self, node, input_circuit, wires=None):
nd["cargs"],
condition_bit_list]
for i in s])}
self._check_wiremap_validity(wire_map, wires,
self.input_map)
self._check_wiremap_validity(wire_map, wires, self.input_map)
pred_map, succ_map = self._make_pred_succ_maps(node)
full_pred_map, full_succ_map = \
self._full_pred_succ_maps(pred_map, succ_map,
input_circuit, wire_map)
full_pred_map, full_succ_map = self._full_pred_succ_maps(pred_map, succ_map,
input_dag, wire_map)
# Now that we know the connections, delete node
self.multi_graph.remove_node(node)
# Iterate over nodes of input_circuit
for m in nx.topological_sort(input_circuit.multi_graph):
md = input_circuit.multi_graph.node[m]
for m in nx.topological_sort(input_dag.multi_graph):
md = input_dag.multi_graph.node[m]
if md["type"] == "op":
# Insert a new node
condition = self._map_condition(wire_map, md["condition"])
Expand Down
6 changes: 3 additions & 3 deletions qiskit/mapper/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,9 @@ def direction_mapper(circuit_graph, coupling_graph):
cxedge[1][0], cxedge[1][1])
continue
elif (cxedge[1], cxedge[0]) in cg_edges:
circuit_graph.substitute_circuit_one(cx_node,
flipped_cx_circuit,
wires=[qr_fcx[0], qr_fcx[1]])
circuit_graph.substitute_node_with_dag(cx_node,
flipped_cx_circuit,
wires=[qr_fcx[0], qr_fcx[1]])
logger.debug("cx %s[%d], %s[%d] -FLIP",
cxedge[0][0], cxedge[0][1],
cxedge[1][0], cxedge[1][1])
Expand Down
1 change: 1 addition & 0 deletions qiskit/transpiler/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .cx_cancellation import CXCancellation
from .fixed_point import FixedPoint
from .decompose import Decompose
from .mapping.check_map import CheckMap
from .mapping.basic_mapper import BasicMapper
from .mapping.direction_mapper import DirectionMapper
Expand Down
43 changes: 43 additions & 0 deletions qiskit/transpiler/passes/decompose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-

# Copyright 2018, IBM.
#
# This source code is licensed under the Apache License, Version 2.0 found in
# the LICENSE.txt file in the root directory of this source tree.

"""Pass for decompose a gate in a circuit."""

from qiskit.transpiler._basepasses import TransformationPass


class Decompose(TransformationPass):
"""
Expand a gate in a circle using its decomposition rules.
"""

def __init__(self, gate=None):
"""
Args:
gate (Gate): Gate to decompose.
"""
super().__init__()
self.gate = gate

def run(self, dag):
"""Expand a given gate into its decomposition.
Args:
dag(DAGCircuit): input dag
Returns:
DAGCircuit: output dag where gate was expanded.
"""
# Walk through the DAG and expand each non-basis node
for node in dag.get_op_nodes(self.gate):
current_node = dag.multi_graph.node[node]

decomposition_rules = current_node["op"].decompositions()

# TODO: allow choosing other possible decompositions
decomposition_dag = decomposition_rules[0]

dag.substitute_node_with_dag(node, decomposition_dag)
return dag
36 changes: 2 additions & 34 deletions qiskit/transpiler/passes/mapping/unroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@

"""Pass for unrolling a circuit to a given basis."""

import networkx as nx

from qiskit.circuit import QuantumRegister, ClassicalRegister
from qiskit.transpiler._basepasses import TransformationPass


Expand All @@ -22,7 +19,7 @@ class Unroller(TransformationPass):
def __init__(self, basis=None):
"""
Args:
basis (list[Instruction]): target basis gates to unroll to
basis (list[Gate]): Target basis gates to unroll to.
"""
super().__init__()
self.basis = basis or []
Expand All @@ -39,9 +36,6 @@ def run(self, dag):
Returns:
DAGCircuit: output unrolled dag
Raises:
TranspilerError: if no decomposition rule is found for an op
"""
# Walk through the DAG and expand each non-basis node
for node in dag.get_gate_nodes():
Expand All @@ -55,31 +49,5 @@ def run(self, dag):
# TODO: allow choosing other possible decompositions
decomposition_dag = self.run(decomposition_rules[0]) # recursively unroll gates

condition = current_node["condition"]
# the decomposition rule must be amended if used in a
# conditional context. delete the op nodes and replay
# them with the condition.
if condition:
decomposition_dag.add_creg(condition[0])
to_replay = []
for n_it in nx.topological_sort(decomposition_dag.multi_graph):
n = decomposition_dag.multi_graph.nodes[n_it]
if n["type"] == "op":
n["op"].control = condition
to_replay.append(n)
for n in decomposition_dag.get_op_nodes():
decomposition_dag._remove_op_node(n)
for n in to_replay:
decomposition_dag.apply_operation_back(n["op"], condition=condition)

# the wires for substitute_circuit_one are expected as qargs first,
# then cargs, then conditions
qwires = [w for w in decomposition_dag.wires
if isinstance(w[0], QuantumRegister)]
cwires = [w for w in decomposition_dag.wires
if isinstance(w[0], ClassicalRegister)]

dag.substitute_circuit_one(node,
decomposition_dag,
qwires + cwires)
dag.substitute_node_with_dag(node, decomposition_dag)
return dag
9 changes: 4 additions & 5 deletions test/python/test_dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def setUp(self):
self.dag.apply_operation_back(XGate(self.qubit1))

def test_substitute_circuit_one_middle(self):
"""The method substitute_circuit_one() replaces a in-the-middle node with a DAG."""
"""The method substitute_node_with_dag() replaces a in-the-middle node with a DAG."""
cx_node = self.dag.get_op_nodes(op=CnotGate(self.qubit0, self.qubit1)).pop()

flipped_cx_circuit = DAGCircuit()
Expand All @@ -390,17 +390,16 @@ def test_substitute_circuit_one_middle(self):
flipped_cx_circuit.apply_operation_back(HGate(v[0]))
flipped_cx_circuit.apply_operation_back(HGate(v[1]))

self.dag.substitute_circuit_one(cx_node, input_circuit=flipped_cx_circuit,
wires=[v[0], v[1]])
self.dag.substitute_node_with_dag(cx_node, flipped_cx_circuit, wires=[v[0], v[1]])

self.assertEqual(self.dag.count_ops()['h'], 5)

def test_substitute_circuit_one_front(self):
"""The method substitute_circuit_one() replaces a leaf-in-the-front node with a DAG."""
"""The method substitute_node_with_dag() replaces a leaf-in-the-front node with a DAG."""
pass

def test_substitute_circuit_one_back(self):
"""The method substitute_circuit_one() replaces a leaf-in-the-back node with a DAG."""
"""The method substitute_node_with_dag() replaces a leaf-in-the-back node with a DAG."""
pass


Expand Down
85 changes: 85 additions & 0 deletions test/python/transpiler/test_decompose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-

# Copyright 2018, IBM.
#
# This source code is licensed under the Apache License, Version 2.0 found in
# the LICENSE.txt file in the root directory of this source tree.

"""Test the decompose pass"""

from sympy import pi

from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit
from qiskit.transpiler.passes import Decompose
from qiskit.converters import circuit_to_dag
from qiskit.extensions.standard import HGate
from qiskit.extensions.standard import ToffoliGate
from ..common import QiskitTestCase


class TestDecompose(QiskitTestCase):
"""Tests the decompose pass."""

def test_basic(self):
"""Test decompose a single H into u2.
"""
qr = QuantumRegister(1, 'qr')
circuit = QuantumCircuit(qr)
circuit.h(qr[0])
dag = circuit_to_dag(circuit)
pass_ = Decompose(HGate(qr[0]))
after_dag = pass_.run(dag)
op_nodes = after_dag.get_op_nodes(data=True)
self.assertEqual(len(op_nodes), 1)
self.assertEqual(op_nodes[0][1]["op"].name, 'u2')

def test_decompose_only_h(self):
"""Test to decompose a single H, without the rest
"""
qr = QuantumRegister(2, 'qr')
circuit = QuantumCircuit(qr)
circuit.h(qr[0])
circuit.cx(qr[0], qr[1])
dag = circuit_to_dag(circuit)
pass_ = Decompose(HGate(qr[0]))
after_dag = pass_.run(dag)
op_nodes = after_dag.get_op_nodes(data=True)
self.assertEqual(len(op_nodes), 2)
for node in op_nodes:
op = node[1]["op"]
self.assertIn(op.name, ['cx', 'u2'])

def test_decompose_toffoli(self):
"""Test decompose CCX.
"""
qr1 = QuantumRegister(2, 'qr1')
qr2 = QuantumRegister(1, 'qr2')
circuit = QuantumCircuit(qr1, qr2)
circuit.ccx(qr1[0], qr1[1], qr2[0])
dag = circuit_to_dag(circuit)
pass_ = Decompose(ToffoliGate(qr1[0], qr1[1], qr2[0]))
after_dag = pass_.run(dag)
op_nodes = after_dag.get_op_nodes(data=True)
self.assertEqual(len(op_nodes), 15)
for node in op_nodes:
op = node[1]["op"]
self.assertIn(op.name, ['h', 't', 'tdg', 'cx'])

def test_decompose_conditional(self):
"""Test decompose a 1-qubit gates with a conditional.
"""
qr = QuantumRegister(1, 'qr')
cr = ClassicalRegister(1, 'cr')
circuit = QuantumCircuit(qr, cr)
circuit.h(qr).c_if(cr, 1)
circuit.x(qr).c_if(cr, 1)
dag = circuit_to_dag(circuit)
pass_ = Decompose(HGate(qr[0]))
after_dag = pass_.run(dag)

ref_circuit = QuantumCircuit(qr, cr)
ref_circuit.u2(0, pi, qr[0]).c_if(cr, 1)
ref_circuit.x(qr).c_if(cr, 1)
ref_dag = circuit_to_dag(ref_circuit)

self.assertEqual(after_dag, ref_dag)

0 comments on commit 6a24148

Please sign in to comment.