Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add repeated structures #136

Merged
merged 14 commits into from
Nov 7, 2024
5 changes: 5 additions & 0 deletions docs/concepts/compilation.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ populated before other children are compiled.
Once this step is completed, we can be sure that all resources and ports of each child are expressed in terms
of global variables, which is a requirement for the next step.

#### Step 2.4: Compilation of repetition

In case a routine is repeated (i.e. has a non-empty `repetition` field), its resources get updated according
to the repetition rules and the repetition specification itself gets updated using the parameter map.

#### Step 2.5: Resource compilation

Resources of the routine are compiled, which is possible only at this stage, as they can be expressed in terms
Expand Down
1,440 changes: 727 additions & 713 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ python = "^3.10"
pydantic = "^2.7"
sympy = "^1.12"
pyparsing ="~3.1.2"
qref = "0.8.0"
qref = "0.9.0"

# A list of all of the optional dependencies, some of which are included in the
# below `extras`. They can be opted into by apps.
Expand Down
6 changes: 6 additions & 0 deletions src/bartiq/_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from qref.schema_v1 import PortV1, ResourceV1, RoutineV1
from typing_extensions import Self, TypedDict

from .repetitions import Repetition, repetition_from_qref, repetition_to_qref
from .symbolics.backend import SymbolicBackend, TExpr

T = TypeVar("T")
Expand Down Expand Up @@ -85,6 +86,7 @@ class _CommonRoutineParams(TypedDict, Generic[T]):
input_params: Iterable[str]
ports: dict[str, Port[T]]
resources: dict[str, Resource[T]]
repetition: Repetition[T] | None
connections: dict[Endpoint, Endpoint]


Expand All @@ -99,6 +101,7 @@ class Routine(Generic[T]):
ports: dict[str, Port[T]]
resources: dict[str, Resource[T]]
connections: dict[Endpoint, Endpoint]
repetition: Repetition | None = None
constraints: Iterable[Constraint[T]] = ()

@property
Expand Down Expand Up @@ -144,6 +147,7 @@ class CompiledRoutine(Generic[T]):
ports: dict[str, Port[T]]
resources: dict[str, Resource[T]]
connections: dict[Endpoint, Endpoint]
repetition: Repetition | None = None
constraints: Iterable[Constraint[T]] = ()

@classmethod
Expand All @@ -164,6 +168,7 @@ def _common_routine_dict_from_qref(qref_obj: AnyQrefType, backend: SymbolicBacke
"ports": {port.name: _port_from_qref(port, backend) for port in program.ports},
"input_params": tuple(program.input_params),
"resources": {resource.name: _resource_from_qref(resource, backend) for resource in program.resources},
"repetition": repetition_from_qref(program.repetition, backend),
"connections": {
_endpoint_from_qref(conn.source): _endpoint_from_qref(conn.target) for conn in program.connections
},
Expand Down Expand Up @@ -225,5 +230,6 @@ def _routine_to_qref_program(routine: Routine[T] | CompiledRoutine[T], backend:
{"source": _endpoint_to_qref(source), "target": _endpoint_to_qref(target)}
for source, target in routine.connections.items()
],
repetition=repetition_to_qref(routine.repetition, backend),
**kwargs,
)
2 changes: 1 addition & 1 deletion src/bartiq/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def cost_func_callable(x) -> float:
if not isinstance(x, (int, float)):
x = x[0]

substituted_expr = backend.substitute(expression, param, x)
substituted_expr = backend.substitute(expression, {param: x})
result = backend.value_of(substituted_expr)
return float(result)

Expand Down
33 changes: 17 additions & 16 deletions src/bartiq/compilation/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dataclasses import dataclass, replace
from typing import Callable, TypeVar

from .._routine import Constraint, ConstraintStatus, Port, Resource
from .._routine import Constraint, ConstraintStatus, Port, Repetition, Resource
from ..symbolics.backend import ComparisonResult, SymbolicBackend, TExpr

T = TypeVar("T", covariant=True)
Expand All @@ -40,15 +40,6 @@ def __init__(self, original_constraint: Constraint[T], compiled_constraint: Cons
super().__init__(original_constraint, compiled_constraint)


def _evaluate_and_define_functions(
expr: TExpr[T], inputs: dict[str, TExpr[T]], custom_funcs: FunctionsMap[T], backend: SymbolicBackend[T]
) -> TExpr[T]:
expr = backend.substitute_all(expr, inputs)
for func_name, func in custom_funcs.items():
expr = backend.define_function(expr, func_name, func)
return value if (value := backend.value_of(expr)) is not None else expr


def evaluate_ports(
ports: dict[str, Port[T]],
inputs: dict[str, TExpr[T]],
Expand All @@ -57,9 +48,7 @@ def evaluate_ports(
) -> dict[str, Port[T]]:
custom_funcs = {} if custom_funcs is None else custom_funcs
return {
name: replace(
port, size=_evaluate_and_define_functions(port.size, inputs, custom_funcs, backend) # type: ignore
)
name: replace(port, size=backend.substitute(port.size, inputs, custom_funcs)) # type: ignore
for name, port in ports.items()
}

Expand All @@ -74,17 +63,29 @@ def evaluate_resources(
return {
name: replace(
resource,
value=_evaluate_and_define_functions(resource.value, inputs, custom_funcs, backend), # type: ignore
value=backend.substitute(resource.value, inputs, custom_funcs), # type: ignore
)
for name, resource in resources.items()
}


def evaluate_repetition(
repetition: Repetition[T] | None,
inputs: dict[str, TExpr[T]],
backend: SymbolicBackend[T],
custom_funcs: FunctionsMap[T] | None = None,
) -> Repetition[T] | None:
if repetition is not None:
return repetition.substitute_symbols(inputs, backend, functions_map=custom_funcs)
else:
return None


def _evaluate_constraint(
constraint: Constraint[T], inputs: dict[str, TExpr[T]], backend: SymbolicBackend[T], custom_funcs: FunctionsMap[T]
) -> Constraint[T]:
lhs = _evaluate_and_define_functions(constraint.lhs, inputs, custom_funcs, backend)
rhs = _evaluate_and_define_functions(constraint.rhs, inputs, custom_funcs, backend)
lhs = backend.substitute(constraint.lhs, inputs, custom_funcs)
rhs = backend.substitute(constraint.rhs, inputs, custom_funcs)

if (comparison_result := backend.compare(lhs, rhs)) == ComparisonResult.equal:
status = ConstraintStatus.satisfied
Expand Down
70 changes: 61 additions & 9 deletions src/bartiq/compilation/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, replace
from graphlib import TopologicalSorter
from typing import Generic, TypeVar

Expand All @@ -24,10 +24,19 @@
from qref.schema_v1 import RoutineV1
from qref.verification import verify_topology

from .._routine import CompiledRoutine, Endpoint, Port, Routine, routine_to_qref
from .._routine import (
CompiledRoutine,
Endpoint,
Port,
Repetition,
Resource,
Routine,
routine_to_qref,
)
from ..errors import BartiqCompilationError
from ..symbolics import sympy_backend
from ..symbolics.backend import SymbolicBackend, TExpr
from ..verification import verify_uncompiled_repetitions
from ._common import (
ConstraintValidationError,
Context,
Expand Down Expand Up @@ -97,10 +106,14 @@ def compile_routine(

"""
if not skip_verification and not isinstance(routine, Routine):
if not (verification_result := verify_topology(routine)):
problems = [problem + "\n" for problem in verification_result.problems]
problems = []
if not (topology_verification_result := verify_topology(routine)):
problems += [problem + "\n" for problem in topology_verification_result.problems]
if not (repetitions_verification_result := verify_uncompiled_repetitions(routine)):
problems += [problem + "\n" for problem in repetitions_verification_result.problems]
if len(problems) > 0:
raise BartiqCompilationError(
f"Found the following issues with the provided routine before the compilation started: {problems}",
f"Found the following issues with the provided routine before the compilation started: \n {problems}",
)
root = routine if isinstance(routine, Routine) else Routine[T].from_qref(ensure_routine(routine), backend)

Expand All @@ -120,7 +133,7 @@ def _compile_local_variables(
compiled_variables: dict[str, TExpr[T]] = {}
extended_inputs = inputs.copy()
for variable in TopologicalSorter(predecessors).static_order():
compiled_value = backend.substitute_all(local_variables[variable], extended_inputs)
compiled_value = backend.substitute(local_variables[variable], extended_inputs)
extended_inputs[variable] = compiled_variables[variable] = compiled_value
return compiled_variables

Expand All @@ -131,7 +144,7 @@ def _compile_linked_params(
parameter_map: ParameterTree[TExpr[T]] = defaultdict(dict)

for source, targets in linked_params.items():
evaluated_source = backend.substitute_all(backend.as_expression(source), inputs)
evaluated_source = backend.substitute(backend.as_expression(source), inputs)
for child, param in targets:
parameter_map[child][param] = evaluated_source

Expand Down Expand Up @@ -159,6 +172,35 @@ def _param_tree_from_compiled_ports(
return param_map


def _process_repeated_resources(
repetition: Repetition,
resources: dict[str, Resource],
children: Sequence[CompiledRoutine[T]],
backend: SymbolicBackend[T],
) -> dict[str, Resource]:
assert len(children) == 1, "Routine with repetition can only have one child."
new_resources = {}
child_resources = children[0].resources

# Ensure that routine with repetition only contains resources that we will later overwrite
for resource_name, resource in resources.items():
assert resource_name in child_resources
assert backend.serialize(resource.value) == f"{children[0].name}.{resource.name}"
for resource in child_resources.values():
if resource.type == "additive":
new_value = repetition.sequence_sum(resource.value, backend)
elif resource.type == "multiplicative":
new_value = repetition.sequence_prod(resource.value, backend)
else:
raise BartiqCompilationError(
f'Can\'t process resource "{resource.name}" of type "{resource.type}" in repetitive structure.'
)

new_resource = replace(resource, value=new_value)
new_resources[resource.name] = new_resource
return new_resources


def _compile(
routine: Routine[T],
backend: SymbolicBackend[T],
Expand Down Expand Up @@ -215,7 +257,16 @@ def _compile(

parameter_map[None] = {**parameter_map[None], **children_variables}

new_resources = evaluate_resources(routine.resources, parameter_map[None], backend)
resources = routine.resources
repetition = routine.repetition

if routine.repetition is not None:
resources = _process_repeated_resources(
routine.repetition, resources, list(compiled_children.values()), backend
)
repetition = routine.repetition.substitute_symbols(parameter_map[None], backend=backend)

new_resources = evaluate_resources(resources, parameter_map[None], backend)

compiled_ports = {
**compiled_ports,
Expand All @@ -239,4 +290,5 @@ def _compile(
resources=new_resources,
constraints=new_constraints,
connections=routine.connections,
repetition=repetition,
)
2 changes: 2 additions & 0 deletions src/bartiq/compilation/_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Context,
evaluate_constraints,
evaluate_ports,
evaluate_repetition,
evaluate_resources,
)

Expand Down Expand Up @@ -109,6 +110,7 @@ def _evaluate_internal(
ports=evaluate_ports(compiled_routine.ports, inputs, backend, functions_map),
resources=evaluate_resources(compiled_routine.resources, inputs, backend, functions_map),
constraints=new_constraints,
repetition=evaluate_repetition(compiled_routine.repetition, inputs, backend, functions_map),
children={
name: _evaluate_internal(
child, inputs, backend=backend, functions_map=functions_map, context=context.descend(name)
Expand Down
41 changes: 29 additions & 12 deletions src/bartiq/compilation/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from collections import defaultdict
from dataclasses import replace
from typing import Callable, TypeVar
Expand Down Expand Up @@ -33,11 +34,11 @@ def _inner(routine: Routine[T], backend: SymbolicBackend[T]) -> Routine[T]:


@postorder_transform
def add_default_additive_resources(routine: Routine[T], backend: SymbolicBackend[T]) -> Routine[T]:
"""Adds additive resources to all the ancestors of a particular having this resource.
def propagate_child_resources(routine: Routine[T], backend: SymbolicBackend[T]) -> Routine[T]:
"""Propagate additive and multiplicative resources to all the ancestors of a child having these resources.

Since additive resources follow simple rules (value of a resource is equal to sum of the resources
of it's children), rather than defining it for all the subroutines, we can just have it defined for
Since additive/multiplicative resources follow simple rules (value of a resource is equal to sum/product of
the resources of it's children), rather than defining it for all the subroutines, we can just have it defined for
appropriate leaves and then "bubble it up" using this preprocessing transformation.

Args:
Expand All @@ -47,14 +48,17 @@ def add_default_additive_resources(routine: Routine[T], backend: SymbolicBackend
Returns:
A routine with all the additive resources defined appropriately at all levels of the hierarchy.
"""
child_resources_map: defaultdict[str, set[str]] = defaultdict(set)
child_additive_resources_map: defaultdict[str, set[str]] = defaultdict(set)
child_multiplicative_resources_map: defaultdict[str, set[str]] = defaultdict(set)

for child in routine.children.values():
for resource in child.resources.values():
if resource.type == ResourceType.additive:
child_resources_map[resource.name].add(child.name)
child_additive_resources_map[resource.name].add(child.name)
if resource.type == ResourceType.multiplicative:
child_multiplicative_resources_map[resource.name].add(child.name)

additional_resources: dict[str, Resource[T]] = {
additive_resources: dict[str, Resource[T]] = { # TODO: try removing & adding [T] to the Resource below?
res_name: Resource(
name=res_name,
type=ResourceType.additive,
Expand All @@ -63,11 +67,24 @@ def add_default_additive_resources(routine: Routine[T], backend: SymbolicBackend
0,
),
)
for res_name, children in child_resources_map.items()
for res_name, children in child_additive_resources_map.items()
if res_name not in routine.resources
}

return replace(routine, resources={**routine.resources, **additional_resources})
multiplicative_resources: dict[str, Resource[T]] = {
res_name: Resource(
name=res_name,
type=ResourceType.multiplicative,
value=math.prod(
(backend.as_expression(f"{child_name}.{res_name}") for child_name in children), # type: ignore
),
)
for res_name, children in child_multiplicative_resources_map.items()
if res_name not in routine.resources
}

extra_resources = {**additive_resources, **multiplicative_resources}
return replace(routine, resources={**routine.resources, **extra_resources})


@postorder_transform
Expand Down Expand Up @@ -140,7 +157,7 @@ def _sort_key(port: Port[T]) -> tuple[bool, str]:
raise BartiqPrecompilationError(
f"Size of the port {port.name} depends on symbols {sorted(missing_symbols)} which are undefined."
)
new_size = backend.substitute_all(port.size, additional_local_variables)
new_size = backend.substitute(port.size, additional_local_variables)
new_ports[port.name] = replace(port, size=new_size)
additional_constraints.append(Constraint(new_variable, new_size))
new_ports[port.name] = replace(port, size=new_variable)
Expand Down Expand Up @@ -170,7 +187,7 @@ def introduce_port_variables(routine: Routine[T], backend: SymbolicBackend[T]) -


def propagate_linked_params(routine: Routine[T], backend: SymbolicBackend[T]) -> Routine[T]:
"""Turns parameter links of level deeper than one into series of direct links.
"""Turns parameter links of level deeper than one into sequence of direct links.

Args:
routine: routine to be preprocessed
Expand Down Expand Up @@ -210,7 +227,7 @@ def propagate_linked_params(routine: Routine[T], backend: SymbolicBackend[T]) ->


DEFAULT_PREPROCESSING_STAGES = (
add_default_additive_resources,
propagate_child_resources,
propagate_linked_params,
promote_unlinked_inputs,
introduce_port_variables,
Expand Down
Loading
Loading