51
51
from braket .devices import Device , LocalSimulator
52
52
from braket .simulator import BraketSimulator
53
53
from braket .tasks import GateModelQuantumTaskResult , QuantumTask
54
+ from braket .tasks .local_quantum_task_batch import LocalQuantumTaskBatch
54
55
from pennylane import QuantumFunctionError , QubitDevice
55
56
from pennylane import numpy as np
56
57
from pennylane .gradients import param_shift
@@ -105,6 +106,12 @@ class BraketQubitDevice(QubitDevice):
105
106
execution.
106
107
verbatim (bool): Whether to run tasks in verbatim mode. Note that verbatim mode only
107
108
supports the native gate set of the device. Default False.
109
+ parallel (bool): Whether to run tasks in parallel if supported by the device backend.
110
+ Default False.
111
+ max_parallel (int, optional): Maximum number of tasks to run on AWS in parallel.
112
+ Batch creation will fail if this value is greater than the maximum allowed concurrent
113
+ tasks on the device. If unspecified, uses defaults defined in ``AwsDevice``.
114
+ Ignored if ``parallel=False``.
108
115
parametrize_differentiable (bool): Whether to bind differentiable parameters (parameters
109
116
marked with ``required_grad=True``) on the Braket device rather than in PennyLane.
110
117
Default: True.
@@ -124,6 +131,8 @@ def __init__(
124
131
shots : Union [int , None ],
125
132
noise_model : Optional [NoiseModel ] = None ,
126
133
verbatim : bool = False ,
134
+ parallel : bool = False ,
135
+ max_parallel : Optional [int ] = None ,
127
136
parametrize_differentiable : bool = True ,
128
137
** run_kwargs ,
129
138
):
@@ -139,6 +148,8 @@ def __init__(
139
148
140
149
super ().__init__ (wires , shots = shots or None )
141
150
self ._device = device
151
+ self ._parallel = parallel
152
+ self ._max_parallel = max_parallel
142
153
self ._circuit = None
143
154
self ._task = None
144
155
self ._noise_model = noise_model
@@ -179,6 +190,49 @@ def task(self) -> QuantumTask:
179
190
"""QuantumTask: The task corresponding to the last run circuit."""
180
191
return self ._task
181
192
193
+ @property
194
+ def parallel (self ) -> bool :
195
+ """bool: Whether the device supports parallel execution of batches."""
196
+ return self ._parallel
197
+
198
+ def batch_execute (self , circuits , ** run_kwargs ):
199
+ if not self ._parallel :
200
+ return super ().batch_execute (circuits )
201
+
202
+ for circuit in circuits :
203
+ self .check_validity (circuit .operations , circuit .observables )
204
+ all_trainable = []
205
+ braket_circuits = []
206
+ for circuit in circuits :
207
+ trainable = (
208
+ BraketQubitDevice ._get_trainable_parameters (circuit )
209
+ if self ._parametrize_differentiable
210
+ else {}
211
+ )
212
+ all_trainable .append (trainable )
213
+ braket_circuits .append (
214
+ self ._pl_to_braket_circuit (
215
+ circuit ,
216
+ trainable_indices = frozenset (trainable .keys ()),
217
+ ** run_kwargs ,
218
+ )
219
+ )
220
+
221
+ batch_shots = 0 if self .analytic else self .shots
222
+
223
+ batch_inputs = (
224
+ [{f"p_{ k } " : v for k , v in trainable .items ()} for trainable in all_trainable ]
225
+ if self ._parametrize_differentiable
226
+ else []
227
+ )
228
+
229
+ braket_results_batch = self ._run_task_batch (braket_circuits , batch_shots , batch_inputs )
230
+
231
+ return [
232
+ self ._braket_to_pl_result (braket_result , circuit )
233
+ for braket_result , circuit in zip (braket_results_batch , circuits )
234
+ ]
235
+
182
236
def _pl_to_braket_circuit (
183
237
self ,
184
238
circuit : QuantumTape ,
@@ -242,6 +296,17 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
242
296
)
243
297
return braket_circuit
244
298
299
+ def _update_tracker_for_batch (
300
+ self , task_batch : Union [AwsQuantumTaskBatch , LocalQuantumTaskBatch ], batch_shots : int
301
+ ):
302
+ for task in task_batch .tasks :
303
+ tracking_data = self ._tracking_data (task )
304
+ self .tracker .update (** tracking_data )
305
+ total_executions = len (task_batch .tasks ) - len (task_batch .unsuccessful )
306
+ total_shots = total_executions * batch_shots
307
+ self .tracker .update (batches = 1 , executions = total_executions , shots = total_shots )
308
+ self .tracker .record ()
309
+
245
310
def statistics (
246
311
self , braket_result : GateModelQuantumTaskResult , measurements : Sequence [MeasurementProcess ]
247
312
) -> list [float ]:
@@ -525,10 +590,6 @@ class BraketAwsQubitDevice(BraketQubitDevice):
525
590
interactions with AWS services, to be supplied if extra control
526
591
is desired. Default: None
527
592
Default: False
528
- max_parallel (int, optional): Maximum number of tasks to run on AWS in parallel.
529
- Batch creation will fail if this value is greater than the maximum allowed concurrent
530
- tasks on the device. If unspecified, uses defaults defined in ``AwsDevice``.
531
- Ignored if ``parallel=False``.
532
593
max_connections (int): The maximum number of connections in the Boto3 connection pool.
533
594
Also the maximum number of thread pool workers for the batch.
534
595
Ignored if ``parallel=False``.
@@ -553,8 +614,6 @@ def __init__(
553
614
poll_timeout_seconds : float = AwsQuantumTask .DEFAULT_RESULTS_POLL_TIMEOUT ,
554
615
poll_interval_seconds : float = AwsQuantumTask .DEFAULT_RESULTS_POLL_INTERVAL ,
555
616
aws_session : Optional [AwsSession ] = None ,
556
- parallel : bool = False ,
557
- max_parallel : Optional [int ] = None ,
558
617
max_connections : int = AwsQuantumTaskBatch .MAX_CONNECTIONS_DEFAULT ,
559
618
max_retries : int = AwsQuantumTaskBatch .MAX_RETRIES ,
560
619
** run_kwargs ,
@@ -579,8 +638,6 @@ def __init__(
579
638
self ._s3_folder = s3_destination_folder
580
639
self ._poll_timeout_seconds = poll_timeout_seconds
581
640
self ._poll_interval_seconds = poll_interval_seconds
582
- self ._parallel = parallel
583
- self ._max_parallel = max_parallel
584
641
self ._max_connections = max_connections
585
642
self ._max_retries = max_retries
586
643
@@ -594,50 +651,19 @@ def use_grouping(self) -> bool:
594
651
caps = self .capabilities ()
595
652
return not ("provides_jacobian" in caps and caps ["provides_jacobian" ])
596
653
597
- @property
598
- def parallel (self ):
599
- return self ._parallel
600
-
601
- def batch_execute (self , circuits , ** run_kwargs ):
602
- if not self ._parallel :
603
- return super ().batch_execute (circuits )
604
-
605
- for circuit in circuits :
606
- self .check_validity (circuit .operations , circuit .observables )
607
- all_trainable = []
608
- braket_circuits = []
609
- for circuit in circuits :
610
- trainable = (
611
- BraketQubitDevice ._get_trainable_parameters (circuit )
612
- if self ._parametrize_differentiable
613
- else {}
614
- )
615
- all_trainable .append (trainable )
616
- braket_circuits .append (
617
- self ._pl_to_braket_circuit (
618
- circuit ,
619
- trainable_indices = frozenset (trainable .keys ()),
620
- ** run_kwargs ,
621
- )
622
- )
623
-
624
- batch_shots = 0 if self .analytic else self .shots
625
-
654
+ def _run_task_batch (self , batch_circuits , batch_shots : int , inputs ):
626
655
task_batch = self ._device .run_batch (
627
- braket_circuits ,
656
+ batch_circuits ,
628
657
s3_destination_folder = self ._s3_folder ,
629
658
shots = batch_shots ,
630
659
max_parallel = self ._max_parallel ,
631
660
max_connections = self ._max_connections ,
632
661
poll_timeout_seconds = self ._poll_timeout_seconds ,
633
662
poll_interval_seconds = self ._poll_interval_seconds ,
634
- inputs = (
635
- [{f"p_{ k } " : v for k , v in trainable .items ()} for trainable in all_trainable ]
636
- if self ._parametrize_differentiable
637
- else []
638
- ),
663
+ inputs = inputs ,
639
664
** self ._run_kwargs ,
640
665
)
666
+
641
667
# Call results() to retrieve the Braket results in parallel.
642
668
try :
643
669
braket_results_batch = task_batch .results (
@@ -647,18 +673,9 @@ def batch_execute(self, circuits, **run_kwargs):
647
673
# Update the tracker before raising an exception further if some circuits do not complete.
648
674
finally :
649
675
if self .tracker .active :
650
- for task in task_batch .tasks :
651
- tracking_data = self ._tracking_data (task )
652
- self .tracker .update (** tracking_data )
653
- total_executions = len (task_batch .tasks ) - len (task_batch .unsuccessful )
654
- total_shots = total_executions * batch_shots
655
- self .tracker .update (batches = 1 , executions = total_executions , shots = total_shots )
656
- self .tracker .record ()
676
+ self ._update_tracker_for_batch (task_batch , batch_shots )
657
677
658
- return [
659
- self ._braket_to_pl_result (braket_result , circuit )
660
- for braket_result , circuit in zip (braket_results_batch , circuits )
661
- ]
678
+ return braket_results_batch
662
679
663
680
def _run_task (self , circuit , inputs = None ):
664
681
return self ._device .run (
@@ -1012,6 +1029,24 @@ def __init__(
1012
1029
device = LocalSimulator (backend )
1013
1030
super ().__init__ (wires , device , shots = shots , ** run_kwargs )
1014
1031
1032
+ def _run_task_batch (self , batch_circuits , batch_shots : int , inputs ):
1033
+ task_batch = self ._device .run_batch (
1034
+ batch_circuits ,
1035
+ shots = batch_shots ,
1036
+ max_parallel = self ._max_parallel ,
1037
+ inputs = inputs ,
1038
+ ** self ._run_kwargs ,
1039
+ )
1040
+
1041
+ # Should not need try-except here as this is a local sim.
1042
+ braket_results_batch = task_batch .results ()
1043
+
1044
+ # Update the tracker
1045
+ if self .tracker .active :
1046
+ self ._update_tracker_for_batch (task_batch , batch_shots )
1047
+
1048
+ return braket_results_batch
1049
+
1015
1050
def _run_task (self , circuit , inputs = None ):
1016
1051
return self ._device .run (
1017
1052
circuit ,
0 commit comments