Skip to content

Commit

Permalink
fix: Add a parameter to control the removal of decomposed resources i…
Browse files Browse the repository at this point in the history
…n the aggregation function (#109)
  • Loading branch information
sitong1011 authored Aug 16, 2024
1 parent c08eb2f commit 37acec8
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 11 deletions.
23 changes: 15 additions & 8 deletions src/bartiq/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@
from collections import defaultdict
from typing import Any, Dict, List, Set

from bartiq import Resource, Routine
from bartiq import Resource, ResourceType, Routine
from bartiq.symbolics import sympy_backend
from bartiq.verification import verify_uncompiled_routine

BACKEND = sympy_backend


def add_aggregated_resources(routine: Routine, aggregation_dict: Dict[str, Dict[str, Any]], backend=BACKEND) -> Routine:
def add_aggregated_resources(
routine: Routine, aggregation_dict: Dict[str, Dict[str, Any]], remove_decomposed: bool = True, backend=BACKEND
) -> Routine:
"""Add aggregated resources to bartiq routine based on the aggregation dictionary.
Args:
Expand All @@ -37,6 +39,10 @@ def add_aggregated_resources(routine: Routine, aggregation_dict: Dict[str, Dict[
"arbitrary_z": {"T_gates": "3*log2(1/epsilon) + O(log(log(1/epsilon)))"},
...
}
remove_decomposed : Whether to remove the decomposed resources from the routine.
Defaults to True.
backend : Backend instance to use for handling expressions.
Defaults to `sympy_backend`.
Returns:
Routine: The program with aggregated resources.
Expand All @@ -46,13 +52,14 @@ def add_aggregated_resources(routine: Routine, aggregation_dict: Dict[str, Dict[

expanded_aggregation_dict = _expand_aggregation_dict(aggregation_dict)
for subroutine in routine.walk():
_add_aggregated_resources_to_subroutine(subroutine, expanded_aggregation_dict)
_add_aggregated_resources_to_subroutine(subroutine, expanded_aggregation_dict, remove_decomposed, backend)
return routine


def _add_aggregated_resources_to_subroutine(
subroutine: Routine, expanded_aggregation_dict: Dict[str, Dict[str, Any]], backend=BACKEND
subroutine: Routine, expanded_aggregation_dict: Dict[str, Dict[str, Any]], remove_decomposed: bool, backend=BACKEND
) -> Routine:

if not hasattr(subroutine, "resources") or not subroutine.resources:
return subroutine

Expand All @@ -73,8 +80,10 @@ def _add_aggregated_resources_to_subroutine(
value=str(multiplier_expr * resource_expr),
)
aggregated_resources[sub_res] = new_resource

del aggregated_resources[resource_name]
if remove_decomposed:
del aggregated_resources[resource_name]
else:
aggregated_resources[resource_name].type = ResourceType.other

subroutine.resources = aggregated_resources
return subroutine
Expand All @@ -87,8 +96,6 @@ def _expand_aggregation_dict(aggregation_dict: Dict[str, Dict[str, Any]], backen
Returns:
Dict[str, Dict[str, Any]]: The expanded aggregation dictionary.
"""
if not isinstance(aggregation_dict, dict):
raise TypeError("aggregation_dict must be a dictionary.")

sorted_resources = _topological_sort(aggregation_dict)

Expand Down
52 changes: 49 additions & 3 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _generate_test(subroutine, input_params, linked_params):


@pytest.mark.parametrize(
"aggregation_dict, generate_test_fn, expected_output",
"aggregation_dict, generate_test_fn, remove_decomposed, expected_output",
[
(
{"control_ry": {"rotation": 2, "CNOT": 2}, "rotation": {"T_gates": 50}},
Expand All @@ -79,6 +79,7 @@ def _generate_test(subroutine, input_params, linked_params):
["z", "num"],
[{"source": "z", "targets": ["ccry_gate.x"]}, {"source": "num", "targets": ["ccry_gate.num"]}],
),
True,
Routine(
name="test_qref",
input_params=["z", "num"],
Expand Down Expand Up @@ -118,6 +119,7 @@ def _generate_test(subroutine, input_params, linked_params):
["z", "num"],
[{"source": "z", "targets": ["arbitrary_z.x"]}, {"source": "num", "targets": ["arbitrary_z.num"]}],
),
True,
Routine(
name="test_qref",
input_params=["z", "num"],
Expand Down Expand Up @@ -152,10 +154,54 @@ def _generate_test(subroutine, input_params, linked_params):
],
),
),
(
{"control_ry": {"rotation": 2, "CNOT": 2}, "rotation": {"T_gates": 50}},
_generate_test(
ccry_gate,
["z", "num"],
[{"source": "z", "targets": ["ccry_gate.x"]}, {"source": "num", "targets": ["ccry_gate.num"]}],
),
False,
Routine(
name="test_qref",
input_params=["z", "num"],
children={
"ccry_gate": Routine(
name="ccry_gate",
type=None,
input_params=["x", "num"],
ports={
"in": {"name": "in", "direction": "input", "size": "n"},
"out": {"name": "out", "direction": "output", "size": "n"},
},
resources={
"CNOT": {"name": "CNOT", "type": "additive", "value": "8*num"},
"T_gates": {"name": "control_ry", "type": "additive", "value": "300*num"},
"control_ry": {
"name": "control_ry",
"type": "other",
"value": "3*num",
},
},
local_variables={"n": "x"},
)
},
type=None,
linked_params={"z": [("ccry_gate", "x")], "num": [("ccry_gate", "num")]},
ports={
"in": {"name": "in", "direction": "input", "size": "z"},
"out": {"name": "out", "direction": "output", "size": "z"},
},
connections=[
{"source": "in", "target": "ccry_gate.in"},
{"source": "ccry_gate.out", "target": "out"},
],
),
),
],
)
def test_add_aggregated_resources(aggregation_dict, generate_test_fn, expected_output):
result = add_aggregated_resources(generate_test_fn, aggregation_dict)
def test_add_aggregated_resources(aggregation_dict, generate_test_fn, remove_decomposed, expected_output):
result = add_aggregated_resources(generate_test_fn, aggregation_dict, remove_decomposed=remove_decomposed)
_compare_routines(result, expected_output)


Expand Down

0 comments on commit 37acec8

Please sign in to comment.