Skip to content

Commit

Permalink
feat: Enable casting to native types and obtaining backend-native fun…
Browse files Browse the repository at this point in the history
…ctions (#135)

* Implement "as_native" method

* Use as_native when converting native Routine to QREF

* Implement backend.func method for obtaining callable objects

* Fix typing mistake
  • Loading branch information
dexter2206 authored Oct 24, 2024
1 parent 759a367 commit 11628ff
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 88 deletions.
4 changes: 2 additions & 2 deletions src/bartiq/_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,13 @@ def _endpoint_from_qref(endpoint: str) -> Endpoint:
def _port_to_qref(port: Port[T], backend: SymbolicBackend[T]) -> PortV1:
return PortV1(
name=port.name,
size=backend.serialize(port.size),
size=backend.as_native(port.size),
direction=cast(Literal["input", "output", "through"], port.direction),
)


def _resource_to_qref(resource: Resource[T], backend: SymbolicBackend[T]) -> ResourceV1:
return ResourceV1(name=resource.name, type=resource.type.value, value=backend.serialize(resource.value))
return ResourceV1(name=resource.name, type=resource.type.value, value=backend.as_native(resource.value))


def _endpoint_to_qref(endpoint: Endpoint) -> str:
Expand Down
6 changes: 6 additions & 0 deletions src/bartiq/symbolics/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class SymbolicBackend(Protocol[T]):
def as_expression(self, value: TExpr[T] | str) -> TExpr[T]:
"""Convert given value into an expression native to this backend."""

def as_native(self, expr: TExpr[T]) -> str | int | float:
"""Convert given expression as an instance of a native type."""

def free_symbols_in(self, expr: TExpr[T], /) -> Iterable[str]:
"""Return an iterable over free symbols in given expression."""

Expand Down Expand Up @@ -75,3 +78,6 @@ def compare(self, lhs: TExpr[T], rhs: TExpr[T]) -> ComparisonResult:
- `ComparisonResult.unequal': 'lhs' and 'rhs' are certainly not equal.
- `ComparisonResult.ambigous`: it is not known for certain if `lhs` and `rhs` are equal.
"""

def func(self, func_name: str) -> Callable[..., TExpr[T]]:
"""Obtain an implementation of a function with given name."""
10 changes: 10 additions & 0 deletions src/bartiq/symbolics/sympy_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ def parse_constant(self, expr: Expr) -> TExpr[Expr]:

return expr

@identity_for_numbers
def as_native(self, expr: Expr) -> str | int | float:
return value if (value := self.value_of(expr)) is not None else self.serialize(expr)

@empty_for_numbers
def free_symbols_in(self, expr: Expr) -> Iterable[str]:
"""Return an iterable over free symbol names in given expression."""
Expand Down Expand Up @@ -201,6 +205,12 @@ def compare(self, lhs: TExpr[Expr], rhs: TExpr[Expr]) -> ComparisonResult:
else:
return ComparisonResult.ambigous

def func(self, func_name: str) -> Callable[..., TExpr[Expr]]:
try:
return SPECIAL_FUNCS[func_name]
except KeyError:
return sympy.Function(func_name)


# Define sympy_backend for backwards compatibility
sympy_backend = SympyBackend(parse_to_sympy)
Expand Down
4 changes: 2 additions & 2 deletions tests/compilation/data/compile/general.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -902,13 +902,13 @@
ports:
- direction: output
name: out_0
size: '1'
size: 1
type: null
- name: b
ports:
- direction: input
name: in_0
size: '1'
size: 1
type: null
connections:
- source: a.out_0
Expand Down
12 changes: 6 additions & 6 deletions tests/compilation/data/compile/passthroughs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,10 @@
ports:
- direction: input
name: in_0
size: '42'
size: 42
- direction: output
name: out_0
size: '42'
size: 42
type: null
connections:
- source: a.out_0
Expand All @@ -225,10 +225,10 @@
ports:
- direction: input
name: in_0
size: '42'
size: 42
- direction: output
name: out_0
size: '42'
size: 42
type: null
connections:
- source: a.out_0
Expand All @@ -239,10 +239,10 @@
ports:
- direction: input
name: in_0
size: '42'
size: 42
- direction: output
name: out_0
size: '42'
size: 42
type: null
version: v1
# Propagation of param through children (this is how we introduced passthroughs)
Expand Down
64 changes: 32 additions & 32 deletions tests/compilation/data/compile/ports_and_topology.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,20 @@
ports:
- direction: input
name: in_0
size: '1'
size: 1
- direction: input
name: in_1
size: '2'
size: 2
- direction: output
name: out_0
size: '1'
size: 1
- direction: output
name: out_1
size: '2'
size: 2
resources:
- name: y
type: other
value: '3'
value: 3
type: null
connections:
- source: a.out_0
Expand All @@ -83,20 +83,20 @@
ports:
- direction: input
name: in_0
size: '1'
size: 1
- direction: input
name: in_1
size: '2'
size: 2
- direction: output
name: out_0
size: '1'
size: 1
- direction: output
name: out_1
size: '2'
size: 2
resources:
- name: z
type: other
value: '3'
value: 3
type: null
version: v1
# Constant input register size with children inputs being described by the same variable
Expand Down Expand Up @@ -156,20 +156,20 @@
ports:
- direction: input
name: in_0
size: '2'
size: 2
- direction: input
name: in_1
size: '2'
size: 2
- direction: output
name: out_0
size: '2'
size: 2
- direction: output
name: out_1
size: '2'
size: 2
resources:
- name: y
type: other
value: '4'
value: 4
type: null
connections:
- source: a.out_0
Expand All @@ -184,20 +184,20 @@
ports:
- direction: input
name: in_0
size: '2'
size: 2
- direction: input
name: in_1
size: '2'
size: 2
- direction: output
name: out_0
size: '2'
size: 2
- direction: output
name: out_1
size: '2'
size: 2
resources:
- name: z
type: other
value: '4'
value: 4
type: null
version: v1
# Constant register size comes from non-root
Expand Down Expand Up @@ -236,16 +236,16 @@
ports:
- direction: output
name: out_0
size: '1'
size: 1
type: null
- name: b
ports:
- direction: input
name: in_0
size: '1'
size: 1
- direction: output
name: out_0
size: '1'
size: 1
type: null
connections:
- source: a.out_0
Expand All @@ -256,7 +256,7 @@
ports:
- direction: output
name: out_0
size: '1'
size: 1
type: null
version: v1
# Parent's and child's ports are connected and the port sizes are defined in both cases (not None)
Expand Down Expand Up @@ -475,13 +475,13 @@
ports:
- direction: input
name: in_0
size: '1'
size: 1
- direction: input
name: in_1
size: N
- direction: output
name: out_0
size: '1'
size: 1
- direction: output
name: out_1
size: N
Expand All @@ -501,13 +501,13 @@
ports:
- direction: input
name: in_0
size: '1'
size: 1
- direction: input
name: in_1
size: N
- direction: output
name: out_0
size: '1'
size: 1
- direction: output
name: out_1
size: N
Expand All @@ -527,13 +527,13 @@
ports:
- direction: input
name: in_0
size: '1'
size: 1
- direction: input
name: in_1
size: N
- direction: output
name: out_0
size: '1'
size: 1
- direction: output
name: out_1
size: N
Expand All @@ -553,13 +553,13 @@
ports:
- direction: input
name: in_0
size: '1'
size: 1
- direction: input
name: in_1
size: N
- direction: output
name: out_0
size: '1'
size: 1
- direction: output
name: out_1
size: N
Expand Down
6 changes: 3 additions & 3 deletions tests/compilation/data/evaluate/constants.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
resources:
- name: T_gates
type: additive
value: '3.14159265358979'
value: 3.14159265358979
type: null
version: v1
# Resource Q is assigned a constant value of sin(pi/2)
Expand All @@ -35,7 +35,7 @@
resources:
- name: T_gates
type: additive
value: "1"
value: 1
type: null
version: v1
# Resource Q is assigned a constant value of 5*e
Expand All @@ -55,6 +55,6 @@
resources:
- name: T_gates
type: additive
value: "13.5914091422952"
value: 13.5914091422952
type: null
version: v1
Loading

0 comments on commit 11628ff

Please sign in to comment.