Skip to content

Commit

Permalink
feat: add postprocessing stages (#153)
Browse files Browse the repository at this point in the history
* feat: add postprocessing stages

* docs: update docs about postprocessing

* style: fix mypy issues
  • Loading branch information
mstechly authored Dec 13, 2024
1 parent cf5ce14 commit 0228662
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 16 deletions.
7 changes: 6 additions & 1 deletion docs/concepts/compilation.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ of the global input symbols of top-level routine.
Compilation can be viewed as recursive process. At every recursive call, several things need to happen in correct order.
Below we outline how the compilation proceeds.

### Step 1: preprocessing
### Step 1: Preprocessing

The Bartiq's compilation engine makes several assumptions about the routine being compiled, which simplify its code
at the expense of flexibility. For instance, Bartiq assumes all port sizes are single parameters of size `#port_name`.
Expand Down Expand Up @@ -124,3 +124,8 @@ of resources of its children.
#### Step 2.6: Output port compilation

Finally, the output ports are compiled, and the new object representing compiled routine is created.

### Step 3: Postprocessing

After compilation is done, there might be certain operations that the user might want to perform on a compiled routine, e.g.: calculating some derived resources.
Currently, there are no postprocessing steps present by default.
13 changes: 10 additions & 3 deletions src/bartiq/compilation/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
evaluate_ports,
evaluate_resources,
)
from .postprocessing import DEFAULT_POSTPROCESSING_STAGES, PostprocessingStage
from .preprocessing import DEFAULT_PREPROCESSING_STAGES, PreprocessingStage

REPETITION_ALLOW_ARBITRARY_RESOURCES_ENV = "BARTIQ_REPETITION_ALLOW_ARBITRARY_RESOURCES"
Expand Down Expand Up @@ -94,6 +95,7 @@ def compile_routine(
*,
backend: SymbolicBackend[T] = sympy_backend,
preprocessing_stages: Iterable[PreprocessingStage[T]] = DEFAULT_PREPROCESSING_STAGES,
postprocessing_stages: Iterable[PostprocessingStage[T]] = DEFAULT_POSTPROCESSING_STAGES,
skip_verification: bool = False,
) -> CompilationResult[T]:
"""Performs symbolic compilation of a given routine.
Expand All @@ -106,6 +108,7 @@ def compile_routine(
backend: a backend used for manipulating symbolic expressions.
preprocessing_stages: functions used for preprocessing of a given routine to make sure it can be correctly
compiled by Bartiq.
postprocessing_stages: functions used for postprocessing of a given routine after compilation is done.
skip_verification: a flag indicating whether verification of the routine should be skipped.
Expand All @@ -122,9 +125,13 @@ def compile_routine(
)
root = routine if isinstance(routine, Routine) else Routine[T].from_qref(ensure_routine(routine), backend)

for stage in preprocessing_stages:
root = stage(root, backend)
return CompilationResult(routine=_compile(root, backend, {}, Context(root.name)), _backend=backend)
for pre_stage in preprocessing_stages:
root = pre_stage(root, backend)
compiled_routine = _compile(root, backend, {}, Context(root.name))
for post_stage in postprocessing_stages:
compiled_routine = post_stage(compiled_routine, backend)

return CompilationResult(routine=compiled_routine, _backend=backend)


def _compile_local_variables(
Expand Down
47 changes: 47 additions & 0 deletions src/bartiq/compilation/postprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2024 PsiQuantum, Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, TypeVar

from .._routine import CompiledRoutine
from ..symbolics.backend import SymbolicBackend
from ..transform import add_aggregated_resources

T = TypeVar("T")

PostprocessingStage = Callable[[CompiledRoutine[T], SymbolicBackend[T]], CompiledRoutine[T]]

DEFAULT_POSTPROCESSING_STAGES: list[PostprocessingStage] = []


def aggregate_resources(
aggregation_dict: dict[str, dict[str, Any]], remove_decomposed: bool = True
) -> PostprocessingStage[T]:
"""Returns a postprocessing stage which aggregates resources using `add_aggregated_resources` method.
This function is just a wrapper around `add_aggregated_resources` method from `bartiq.transform.
For more details how it works, please see its documentation.
Args
aggregation_dict: A dictionary that decomposes resources into more fundamental components along with their
respective multipliers.
remove_decomposed : Whether to remove the decomposed resources from the routine.
Defaults to True.
"""

def _inner(routine: CompiledRoutine[T], backend: SymbolicBackend[T]) -> CompiledRoutine[T]:
return add_aggregated_resources(routine, aggregation_dict, remove_decomposed, backend) # TODO: Konrad mypy

return _inner
19 changes: 16 additions & 3 deletions src/bartiq/compilation/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
# Copyright 2024 PsiQuantum, Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import defaultdict
from dataclasses import replace
from typing import Callable, TypeVar

from bartiq.errors import BartiqPrecompilationError

from .._routine import Constraint, Port, PortDirection, Resource, ResourceType, Routine
from ..errors import BartiqPreprocessingError
from ..symbolics.backend import SymbolicBackend, TExpr

T = TypeVar("T")
Expand Down Expand Up @@ -154,7 +167,7 @@ def _sort_key(port: Port[T]) -> tuple[bool, str]:
)
]
if missing_symbols:
raise BartiqPrecompilationError(
raise BartiqPreprocessingError(
f"Size of the port {port.name} depends on symbols {sorted(missing_symbols)} which are undefined."
)
new_size = backend.substitute(port.size, additional_local_variables)
Expand Down
4 changes: 2 additions & 2 deletions src/bartiq/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.


class BartiqPrecompilationError(Exception):
"""Raised for errors during Bartiq function pre-compilation."""
class BartiqPreprocessingError(Exception):
"""Raised for errors during Bartiq function pre-processing."""


class BartiqCompilationError(Exception):
Expand Down
11 changes: 6 additions & 5 deletions src/bartiq/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,23 @@
from graphlib import TopologicalSorter
from typing import Any, TypeVar

from bartiq import Resource, ResourceType, Routine
from bartiq import CompiledRoutine, Resource, ResourceType, Routine
from bartiq.symbolics import sympy_backend
from bartiq.symbolics.backend import SymbolicBackend, TExpr

BACKEND = sympy_backend


T = TypeVar("T")
AnyRoutine = TypeVar("AnyRoutine", Routine, CompiledRoutine)


def add_aggregated_resources(
routine: Routine[T],
routine: AnyRoutine,
aggregation_dict: dict[str, dict[str, Any]],
remove_decomposed: bool = True,
backend: SymbolicBackend[T] = sympy_backend,
) -> Routine[T]:
) -> AnyRoutine:
"""Add aggregated resources to bartiq routine based on the aggregation dictionary.
Args:
Expand All @@ -60,11 +61,11 @@ def add_aggregated_resources(


def _add_aggregated_resources_to_subroutine(
subroutine: Routine[T],
subroutine: AnyRoutine,
expanded_aggregation_dict: dict[str, dict[str, str | TExpr[T]]],
remove_decomposed: bool,
backend: SymbolicBackend[T] = BACKEND,
) -> Routine[T]:
) -> AnyRoutine:
new_children = {
name: _add_aggregated_resources_to_subroutine(child, expanded_aggregation_dict, remove_decomposed, backend)
for name, child in subroutine.children.items()
Expand Down
14 changes: 14 additions & 0 deletions src/bartiq/verification.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright 2024 PsiQuantum, Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from qref.functools import accepts_all_qref_types
Expand Down
4 changes: 2 additions & 2 deletions tests/compilation/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from bartiq import compile_routine
from bartiq.compilation.preprocessing import introduce_port_variables
from bartiq.errors import BartiqCompilationError, BartiqPrecompilationError
from bartiq.errors import BartiqCompilationError, BartiqPreprocessingError


def load_compile_test_data():
Expand Down Expand Up @@ -217,6 +217,6 @@ def test_compilation_fails_if_input_ports_has_size_depending_on_undefined_variab
}

with pytest.raises(
BartiqPrecompilationError, match=r"Size of the port in_0 depends on symbols \['M', 'N'\] which are undefined."
BartiqPreprocessingError, match=r"Size of the port in_0 depends on symbols \['M', 'N'\] which are undefined."
):
compile_routine(routine, backend=backend)
63 changes: 63 additions & 0 deletions tests/compilation/test_postprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from dataclasses import replace

from qref.schema_v1 import RoutineV1

from bartiq import compile_routine
from bartiq._routine import Routine
from bartiq.compilation.postprocessing import aggregate_resources


def _get_routine(backend):
qref_routine = RoutineV1(
name="root",
type=None,
children=[
{
"name": "child_1",
"type": None,
"resources": [
{"name": "a", "type": "additive", "value": 1},
{"name": "b", "type": "additive", "value": 5},
],
},
{
"name": "child_2",
"type": None,
"resources": [
{"name": "a", "type": "additive", "value": 2},
{"name": "b", "type": "additive", "value": 3},
{"name": "c", "type": "additive", "value": 1},
],
},
],
)
return Routine.from_qref(qref_routine, backend)


def test_aggregate_resources_as_postprocessing(backend):
routine = _get_routine(backend)
aggregation_dict = {"a": {"op": 1}, "b": {"op": 2}, "c": {"op": 3}}
postprocessing_stages = [aggregate_resources(aggregation_dict, remove_decomposed=True)]
compiled_routine = compile_routine(routine, postprocessing_stages=postprocessing_stages, backend=backend).routine
assert len(compiled_routine.resources) == 1
assert compiled_routine.resources["op"].value == 22


def test_two_postprocessing_stages(backend):
routine = _get_routine(backend)

def stage_1(routine, backend):
return replace(routine, name=routine.name.upper())

def stage_2(routine, backend):
cool_children = routine.children
for child_name, child in cool_children.items():
cool_children[child_name] = replace(child, type="cool_kid")
return replace(routine, children=cool_children)

postprocessing_stages = [stage_1, stage_2]
compiled_routine = compile_routine(routine, postprocessing_stages=postprocessing_stages, backend=backend).routine

assert compiled_routine.name == "ROOT"
for child in compiled_routine.children.values():
assert child.type == "cool_kid"

0 comments on commit 0228662

Please sign in to comment.