Skip to content

Commit

Permalink
fix: assigning value to evaluated expressions and pruning linked_para…
Browse files Browse the repository at this point in the history
…ms (#52)

* fix: assigning value to evaluated expressions and pruning linked_params in evaluation

* fix: update notebook
  • Loading branch information
mstechly authored May 29, 2024
1 parent 1dfac28 commit 0c9c872
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 22 deletions.
34 changes: 19 additions & 15 deletions docs/tutorials/02_alias_sampling_basic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"<div class=\"alert alert-block alert-info admonition note\"> <p class=\"admonition-title\"><b>NOTE:</b></p>\n",
"\n",
"This tutorial, as well as all the other tutorials, has been written as a jupyter notebook.\n",
"If you're reading it online, you can either keep reading, or go to `docs/tutorials` to explore them in a more interactive way!\n",
"If you're reading it online, you can either keep reading, or clone the repository and go to `docs/tutorials` to explore them in a more interactive way!\n",
"\n",
"</div>"
]
Expand Down Expand Up @@ -330,8 +330,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"rotations: 2\n",
"T_gates: 4*L + 8*L/multiplicity(2, L) + 4*mu + swap.O(log2(L)) - 8\n"
"T_gates: 4*L + 8*L/multiplicity(2, L) + 4*mu + swap.O(log2(L)) - 8\n",
"rotations: 2\n"
]
}
],
Expand Down Expand Up @@ -367,8 +367,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"rotations: 2\n",
"T_gates: swap.O(log2(120)) + 824\n"
"T_gates: swap.O(log2(120)) + 824\n",
"rotations: 2\n"
]
}
],
Expand Down Expand Up @@ -403,8 +403,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"rotations: 2\n",
"T_gates: 4*L + 8*L/multiplicity(2, L) + 4*mu + O(log2(L)) - 8\n"
"T_gates: 4*L + 8*L/multiplicity(2, L) + 4*mu + O(log2(L)) - 8\n",
"rotations: 2\n"
]
}
],
Expand All @@ -419,7 +419,7 @@
"id": "e090cce2-8b9f-449f-8c14-96212cbe7f85",
"metadata": {},
"source": [
"We still have big O there, but at least now we got rid of the `swap`. So let's assume the simplest case, i.e.`O(x) = x`."
"We still have big O there, but at least now we got rid of the `swap`. So let's assume the simplest case, i.e.`O(x) = ceiling(x)` "
]
},
{
Expand All @@ -432,14 +432,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"rotations: 2\n",
"T_gates: log2(120) + 824\n"
"T_gates: 831\n",
"rotations: 2\n"
]
}
],
"source": [
"import math\n",
"\n",
"def big_O(x):\n",
" return x\n",
" return math.ceil(x)\n",
"\n",
"functions_map = {\"O\": big_O}\n",
"evaluated_routine = evaluate(compiled_routine, assignments, functions_map=functions_map)\n",
Expand All @@ -465,8 +467,10 @@
"If we just interact with bare python objects, getting a quick idea of the values of various fields might be a bit cumbersome.\n",
"That's where `explore_routine` functions might be helpful. Try it out using the snippet below.\n",
"\n",
"\n",
"<div class=\"alert alert-block alert-info admonition note\"> <p class=\"admonition-title\"><b>NOTE:</b></p>\n",
"This is an interactive feature and will not render in the static version of the docs. To use it you need to run this tutorial as a jupyter notebook.\n",
"This is an interactive feature and will not render in the static version of the docs. To use it you need to run this tutorial as a jupyter notebook. <br>\n",
"Remember to install bartiq with <code>pip install bartiq[jupyter]</code> to make sure you have all the dependencies needed for these widgets to work (for more details visit <a href=\"https://psiq.github.io/bartiq/latest/installation/\">installation docs</a>).\n",
"</div>\n"
]
},
Expand All @@ -479,7 +483,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0abdf57a3f6c40318de80a701c9e3610",
"model_id": "811e911faa5847d59817f59b97115522",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -529,13 +533,13 @@
"&\\text{temp\\_2} = 8\\\\\n",
"&\\text{temp\\_3} = 1\\newline\n",
"&\\underline{\\text{Resources:}}\\\\\n",
"&T_{\\text{gates}} = 831\\\\\n",
"&rotations = 2\\\\\n",
"&T_{\\text{gates}} = \\operatorname{log}_{2}{\\left(120 \\right)} + 824\\\\\n",
"&\\text{usp}.\\!T_{\\text{gates}} = 320\\\\\n",
"&\\text{usp}.\\!rotations = 2\\\\\n",
"&\\text{qrom}.\\!T_{\\text{gates}} = 476\\\\\n",
"&\\text{compare}.\\!T_{\\text{gates}} = 28\\\\\n",
"&\\text{swap}.\\!T_{\\text{gates}} = \\operatorname{log}_{2}{\\left(120 \\right)}\n",
"&\\text{swap}.\\!T_{\\text{gates}} = 7\n",
"\\end{align}$"
],
"text/plain": [
Expand Down
2 changes: 0 additions & 2 deletions src/bartiq/compilation/_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ def _evaluate(
for parsed_assignment in parsed_assignments:
_evaluate_over_assignment(evaluated_routine, parsed_assignment, backend, functions_map)

# TODO: This is just for backward compatibility and making sure tests pass. # noqa:T101
evaluated_routine.linked_params = {}
return evaluated_routine


Expand Down
10 changes: 9 additions & 1 deletion src/bartiq/compilation/_symbolic_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,10 @@ def update_routine_with_symbolic_function(routine: Routine, function: SymbolicFu
input_params, input_register_sizes_from_inputs = _parse_function_inputs(function)
costs, registers_sizes_from_outputs = _parse_function_outputs(function, input_register_sizes_from_inputs)
routine.input_params = sorted(input_params)
linked_params_to_remove = set(routine.linked_params.keys()) - set(input_params)
for param in linked_params_to_remove:
del routine.linked_params[param]

for port_name, port_size in input_register_sizes_from_inputs.items():
routine.input_ports[port_name].size = str(port_size)
for port_name, port_size in registers_sizes_from_outputs.items():
Expand Down Expand Up @@ -769,6 +773,10 @@ def _parse_function_outputs(function, input_register_sizes_from_inputs):
f"got {type(output_variable)}"
)
else:
costs.append(f"{output_symbol} = {output_variable.evaluated_expression}")
cost_value = (
output_variable.evaluated_expression if output_variable.value is None else output_variable.value
)

costs.append(f"{output_symbol} = {cost_value}")

return costs, register_sizes
190 changes: 186 additions & 4 deletions tests/compilation/data/evaluate_test_data.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,42 @@
"type": null
}
],
[
{
"name": "",
"type": null,
"input_params": [
"x"
],
"resources": {
"Q": {
"name": "Q",
"type": "other",
"value": {
"type": "str",
"value": "log2(x)"
}
}
}
},
[
"x=120"
],
{
"name": "",
"type": null,
"resources": {
"Q": {
"name": "Q",
"type": "other",
"value": {
"type": "str",
"value": "6.90689059560852"
}
}
}
}
],
[
{
"name": "",
Expand Down Expand Up @@ -1800,7 +1836,10 @@
"value": "N"
}
}
}
},
"input_params": [
"y"
]
}
},
"connections": [
Expand All @@ -1810,14 +1849,21 @@
}
],
"input_params": [
"x"
"x",
"y"
],
"linked_params": {
"x": [
[
"a",
"x"
]
],
"y": [
[
"b",
"y"
]
]
}
},
Expand Down Expand Up @@ -1854,15 +1900,29 @@
"value": "0"
}
}
}
},
"input_params": [
"y"
]
}
},
"connections": [
{
"source": "a.out_0",
"target": "b.in_0"
}
]
],
"input_params": [
"y"
],
"linked_params": {
"y": [
[
"b",
"y"
]
]
}
}
],
[
Expand Down Expand Up @@ -1936,5 +1996,127 @@
}
}
}
],
[
{
"name": "",
"type": null,
"input_params": [
"x",
"y"
],
"linked_params": {
"x": [
[
"a",
"x"
]
],
"y": [
[
"a",
"y"
]
]
},
"children": {
"a": {
"name": "a",
"type": null,
"input_params": [
"x",
"y"
],
"linked_params": {
"x": [
[
"b",
"x"
]
],
"y": [
[
"b",
"y"
]
]
},
"children": {
"b": {
"name": "b",
"type": null,
"input_params": [
"x",
"y"
],
"resources": {
"Q": {
"name": "Q",
"type": "other",
"value": {
"type": "str",
"value": "x + y"
}
}
}
}
}
}
}
},
[
"x=10"
],
{
"name": "",
"type": null,
"input_params": [
"y"
],
"linked_params": {
"y": [
[
"a",
"y"
]
]
},
"children": {
"a": {
"name": "a",
"type": null,
"input_params": [
"y"
],
"linked_params": {
"y": [
[
"b",
"y"
]
]
},
"children": {
"b": {
"name": "b",
"type": null,
"input_params": [
"y"
],
"resources": {
"Q": {
"name": "Q",
"type": "other",
"value": {
"type": "str",
"value": "y + 10"
}
}
}
}
}
}
}
}
]
]

0 comments on commit 0c9c872

Please sign in to comment.