Skip to content

Commit 1377abd

Browse files
authored
feat: Support batch_execute for BraketLocalQubitDevice (#269)
1 parent dfff7a6 commit 1377abd

File tree

4 files changed

+254
-100
lines changed

4 files changed

+254
-100
lines changed

doc/devices/braket_local.rst

+26
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,32 @@ When executed, the circuit will perform the computation on the local machine.
4242
>>> circuit(0.2, 0.1, 0.3)
4343
array([0.97517033, 0.04904283])
4444

45+
Enabling the parallel execution of multiple circuits
46+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
47+
48+
Where supported by the backend of the local simulator, the local device can be used to execute multiple
49+
quantum circuits in parallel. To unlock this feature, instantiate the device using the ``parallel=True`` argument:
50+
51+
>>> local_device = qml.device('braket.local.qubit', [... ,] parallel=True)
52+
53+
The details of the parallelization scheme depend on the PennyLane version you use, as well as the specific local simulator
54+
backend you use.
55+
56+
For example, PennyLane 0.13.0 and higher supports the parallel execution of circuits created during the computation of gradients.
57+
Just by instantiating the remote device with the ``parallel=True`` option, this feature is automatically used and can
58+
lead to significant speedups of your optimization pipeline.
59+
60+
The maximum number of circuits that can be executed in parallel is specified by the ``max_parallel`` argument.
61+
62+
>>> local_device = qml.device('braket.local.qubit', [... ,] parallel=True, max_parallel=20)
63+
64+
If ``max_parallel`` is not specified, the local simulator backend will use its own default. Each parallel execution
65+
will use additional memory, so be careful not to set ``max_parallel`` so high that you run out of memory on your local
66+
device. The exact limit will depend on your device. Additionally, setting ``max_parallel`` much higher than the number of
67+
CPU cores available (if you are using a CPU-based local simulator) or GPUs/GPU streams (if you are using a GPU-based local
68+
simulator) will not improve and may even degrade performance as too many parallel workers begin to contend for the same
69+
scarce resources.
70+
4571
Device options
4672
~~~~~~~~~~~~~~
4773

src/braket/pennylane_plugin/braket_device.py

+89-54
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from braket.devices import Device, LocalSimulator
5252
from braket.simulator import BraketSimulator
5353
from braket.tasks import GateModelQuantumTaskResult, QuantumTask
54+
from braket.tasks.local_quantum_task_batch import LocalQuantumTaskBatch
5455
from pennylane import QuantumFunctionError, QubitDevice
5556
from pennylane import numpy as np
5657
from pennylane.gradients import param_shift
@@ -105,6 +106,12 @@ class BraketQubitDevice(QubitDevice):
105106
execution.
106107
verbatim (bool): Whether to run tasks in verbatim mode. Note that verbatim mode only
107108
supports the native gate set of the device. Default False.
109+
parallel (bool): Whether to run tasks in parallel if supported by the device backend.
110+
Default False.
111+
max_parallel (int, optional): Maximum number of tasks to run on AWS in parallel.
112+
Batch creation will fail if this value is greater than the maximum allowed concurrent
113+
tasks on the device. If unspecified, uses defaults defined in ``AwsDevice``.
114+
Ignored if ``parallel=False``.
108115
parametrize_differentiable (bool): Whether to bind differentiable parameters (parameters
109116
marked with ``required_grad=True``) on the Braket device rather than in PennyLane.
110117
Default: True.
@@ -124,6 +131,8 @@ def __init__(
124131
shots: Union[int, None],
125132
noise_model: Optional[NoiseModel] = None,
126133
verbatim: bool = False,
134+
parallel: bool = False,
135+
max_parallel: Optional[int] = None,
127136
parametrize_differentiable: bool = True,
128137
**run_kwargs,
129138
):
@@ -139,6 +148,8 @@ def __init__(
139148

140149
super().__init__(wires, shots=shots or None)
141150
self._device = device
151+
self._parallel = parallel
152+
self._max_parallel = max_parallel
142153
self._circuit = None
143154
self._task = None
144155
self._noise_model = noise_model
@@ -179,6 +190,49 @@ def task(self) -> QuantumTask:
179190
"""QuantumTask: The task corresponding to the last run circuit."""
180191
return self._task
181192

193+
@property
194+
def parallel(self) -> bool:
195+
"""bool: Whether the device supports parallel execution of batches."""
196+
return self._parallel
197+
198+
def batch_execute(self, circuits, **run_kwargs):
199+
if not self._parallel:
200+
return super().batch_execute(circuits)
201+
202+
for circuit in circuits:
203+
self.check_validity(circuit.operations, circuit.observables)
204+
all_trainable = []
205+
braket_circuits = []
206+
for circuit in circuits:
207+
trainable = (
208+
BraketQubitDevice._get_trainable_parameters(circuit)
209+
if self._parametrize_differentiable
210+
else {}
211+
)
212+
all_trainable.append(trainable)
213+
braket_circuits.append(
214+
self._pl_to_braket_circuit(
215+
circuit,
216+
trainable_indices=frozenset(trainable.keys()),
217+
**run_kwargs,
218+
)
219+
)
220+
221+
batch_shots = 0 if self.analytic else self.shots
222+
223+
batch_inputs = (
224+
[{f"p_{k}": v for k, v in trainable.items()} for trainable in all_trainable]
225+
if self._parametrize_differentiable
226+
else []
227+
)
228+
229+
braket_results_batch = self._run_task_batch(braket_circuits, batch_shots, batch_inputs)
230+
231+
return [
232+
self._braket_to_pl_result(braket_result, circuit)
233+
for braket_result, circuit in zip(braket_results_batch, circuits)
234+
]
235+
182236
def _pl_to_braket_circuit(
183237
self,
184238
circuit: QuantumTape,
@@ -242,6 +296,17 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
242296
)
243297
return braket_circuit
244298

299+
def _update_tracker_for_batch(
300+
self, task_batch: Union[AwsQuantumTaskBatch, LocalQuantumTaskBatch], batch_shots: int
301+
):
302+
for task in task_batch.tasks:
303+
tracking_data = self._tracking_data(task)
304+
self.tracker.update(**tracking_data)
305+
total_executions = len(task_batch.tasks) - len(task_batch.unsuccessful)
306+
total_shots = total_executions * batch_shots
307+
self.tracker.update(batches=1, executions=total_executions, shots=total_shots)
308+
self.tracker.record()
309+
245310
def statistics(
246311
self, braket_result: GateModelQuantumTaskResult, measurements: Sequence[MeasurementProcess]
247312
) -> list[float]:
@@ -525,10 +590,6 @@ class BraketAwsQubitDevice(BraketQubitDevice):
525590
interactions with AWS services, to be supplied if extra control
526591
is desired. Default: None
527592
Default: False
528-
max_parallel (int, optional): Maximum number of tasks to run on AWS in parallel.
529-
Batch creation will fail if this value is greater than the maximum allowed concurrent
530-
tasks on the device. If unspecified, uses defaults defined in ``AwsDevice``.
531-
Ignored if ``parallel=False``.
532593
max_connections (int): The maximum number of connections in the Boto3 connection pool.
533594
Also the maximum number of thread pool workers for the batch.
534595
Ignored if ``parallel=False``.
@@ -553,8 +614,6 @@ def __init__(
553614
poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
554615
poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
555616
aws_session: Optional[AwsSession] = None,
556-
parallel: bool = False,
557-
max_parallel: Optional[int] = None,
558617
max_connections: int = AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT,
559618
max_retries: int = AwsQuantumTaskBatch.MAX_RETRIES,
560619
**run_kwargs,
@@ -579,8 +638,6 @@ def __init__(
579638
self._s3_folder = s3_destination_folder
580639
self._poll_timeout_seconds = poll_timeout_seconds
581640
self._poll_interval_seconds = poll_interval_seconds
582-
self._parallel = parallel
583-
self._max_parallel = max_parallel
584641
self._max_connections = max_connections
585642
self._max_retries = max_retries
586643

@@ -594,50 +651,19 @@ def use_grouping(self) -> bool:
594651
caps = self.capabilities()
595652
return not ("provides_jacobian" in caps and caps["provides_jacobian"])
596653

597-
@property
598-
def parallel(self):
599-
return self._parallel
600-
601-
def batch_execute(self, circuits, **run_kwargs):
602-
if not self._parallel:
603-
return super().batch_execute(circuits)
604-
605-
for circuit in circuits:
606-
self.check_validity(circuit.operations, circuit.observables)
607-
all_trainable = []
608-
braket_circuits = []
609-
for circuit in circuits:
610-
trainable = (
611-
BraketQubitDevice._get_trainable_parameters(circuit)
612-
if self._parametrize_differentiable
613-
else {}
614-
)
615-
all_trainable.append(trainable)
616-
braket_circuits.append(
617-
self._pl_to_braket_circuit(
618-
circuit,
619-
trainable_indices=frozenset(trainable.keys()),
620-
**run_kwargs,
621-
)
622-
)
623-
624-
batch_shots = 0 if self.analytic else self.shots
625-
654+
def _run_task_batch(self, batch_circuits, batch_shots: int, inputs):
626655
task_batch = self._device.run_batch(
627-
braket_circuits,
656+
batch_circuits,
628657
s3_destination_folder=self._s3_folder,
629658
shots=batch_shots,
630659
max_parallel=self._max_parallel,
631660
max_connections=self._max_connections,
632661
poll_timeout_seconds=self._poll_timeout_seconds,
633662
poll_interval_seconds=self._poll_interval_seconds,
634-
inputs=(
635-
[{f"p_{k}": v for k, v in trainable.items()} for trainable in all_trainable]
636-
if self._parametrize_differentiable
637-
else []
638-
),
663+
inputs=inputs,
639664
**self._run_kwargs,
640665
)
666+
641667
# Call results() to retrieve the Braket results in parallel.
642668
try:
643669
braket_results_batch = task_batch.results(
@@ -647,18 +673,9 @@ def batch_execute(self, circuits, **run_kwargs):
647673
# Update the tracker before raising an exception further if some circuits do not complete.
648674
finally:
649675
if self.tracker.active:
650-
for task in task_batch.tasks:
651-
tracking_data = self._tracking_data(task)
652-
self.tracker.update(**tracking_data)
653-
total_executions = len(task_batch.tasks) - len(task_batch.unsuccessful)
654-
total_shots = total_executions * batch_shots
655-
self.tracker.update(batches=1, executions=total_executions, shots=total_shots)
656-
self.tracker.record()
676+
self._update_tracker_for_batch(task_batch, batch_shots)
657677

658-
return [
659-
self._braket_to_pl_result(braket_result, circuit)
660-
for braket_result, circuit in zip(braket_results_batch, circuits)
661-
]
678+
return braket_results_batch
662679

663680
def _run_task(self, circuit, inputs=None):
664681
return self._device.run(
@@ -1012,6 +1029,24 @@ def __init__(
10121029
device = LocalSimulator(backend)
10131030
super().__init__(wires, device, shots=shots, **run_kwargs)
10141031

1032+
def _run_task_batch(self, batch_circuits, batch_shots: int, inputs):
1033+
task_batch = self._device.run_batch(
1034+
batch_circuits,
1035+
shots=batch_shots,
1036+
max_parallel=self._max_parallel,
1037+
inputs=inputs,
1038+
**self._run_kwargs,
1039+
)
1040+
1041+
# Should not need try-except here as this is a local sim.
1042+
braket_results_batch = task_batch.results()
1043+
1044+
# Update the tracker
1045+
if self.tracker.active:
1046+
self._update_tracker_for_batch(task_batch, batch_shots)
1047+
1048+
return braket_results_batch
1049+
10151050
def _run_task(self, circuit, inputs=None):
10161051
return self._device.run(
10171052
circuit,

0 commit comments

Comments
 (0)