14
14
"""
15
15
Class implementation for the Quimb MPS interface for simulating quantum circuits while keeping the state always in MPS form.
16
16
"""
17
+ import copy
18
+ from typing import Callable , Sequence , Union
17
19
18
- import numpy as np
19
20
import pennylane as qml
20
21
import quimb .tensor as qtn
22
+ from pennylane import numpy as np
23
+ from pennylane .devices import DefaultExecutionConfig , ExecutionConfig
24
+ from pennylane .devices .preprocess import (
25
+ decompose ,
26
+ validate_device_wires ,
27
+ validate_measurements ,
28
+ validate_observables ,
29
+ )
30
+ from pennylane .measurements import ExpectationMP , MeasurementProcess , StateMeasurement , VarianceMP
31
+ from pennylane .tape import QuantumScript , QuantumTape
32
+ from pennylane .transforms .core import TransformProgram
33
+ from pennylane .typing import Result , ResultBatch , TensorLike
21
34
from pennylane .wires import Wires
22
35
23
- _operations = frozenset ({}) # pragma: no cover
36
+ Result_or_ResultBatch = Union [Result , ResultBatch ]
37
+ QuantumTapeBatch = Sequence [QuantumTape ]
38
+ QuantumTape_or_Batch = Union [QuantumTape , QuantumTapeBatch ]
39
+ PostprocessingFn = Callable [[ResultBatch ], Result_or_ResultBatch ]
40
+
41
+ _operations = frozenset (
42
+ {
43
+ "Identity" ,
44
+ "QubitUnitary" ,
45
+ "ControlledQubitUnitary" ,
46
+ "MultiControlledX" ,
47
+ "DiagonalQubitUnitary" ,
48
+ "PauliX" ,
49
+ "PauliY" ,
50
+ "PauliZ" ,
51
+ "MultiRZ" ,
52
+ "GlobalPhase" ,
53
+ "Hadamard" ,
54
+ "S" ,
55
+ "T" ,
56
+ "SX" ,
57
+ "CNOT" ,
58
+ "SWAP" ,
59
+ "ISWAP" ,
60
+ "PSWAP" ,
61
+ "SISWAP" ,
62
+ "SQISW" ,
63
+ "CSWAP" ,
64
+ "Toffoli" ,
65
+ "CY" ,
66
+ "CZ" ,
67
+ "PhaseShift" ,
68
+ "ControlledPhaseShift" ,
69
+ "CPhase" ,
70
+ "RX" ,
71
+ "RY" ,
72
+ "RZ" ,
73
+ "Rot" ,
74
+ "CRX" ,
75
+ "CRY" ,
76
+ "CRZ" ,
77
+ "CRot" ,
78
+ "IsingXX" ,
79
+ "IsingYY" ,
80
+ "IsingZZ" ,
81
+ "IsingXY" ,
82
+ "SingleExcitation" ,
83
+ "SingleExcitationPlus" ,
84
+ "SingleExcitationMinus" ,
85
+ "DoubleExcitation" ,
86
+ "QubitCarry" ,
87
+ "QubitSum" ,
88
+ "OrbitalRotation" ,
89
+ "QFT" ,
90
+ "ECR" ,
91
+ "BlockEncode" ,
92
+ }
93
+ )
24
94
# The set of supported operations.
25
95
26
- _observables = frozenset ({}) # pragma: no cover
96
+
97
+ _observables = frozenset (
98
+ {
99
+ "PauliX" ,
100
+ "PauliY" ,
101
+ "PauliZ" ,
102
+ "Hadamard" ,
103
+ "Hermitian" ,
104
+ "Identity" ,
105
+ "Projector" ,
106
+ "SparseHamiltonian" ,
107
+ "Hamiltonian" ,
108
+ "LinearCombination" ,
109
+ "Sum" ,
110
+ "SProd" ,
111
+ "Prod" ,
112
+ "Exp" ,
113
+ }
114
+ )
27
115
# The set of supported observables.
28
116
29
117
@@ -57,15 +145,15 @@ def __init__(self, num_wires, interf_opts, dtype=np.complex128):
57
145
self ._wires = Wires (range (num_wires ))
58
146
self ._dtype = dtype
59
147
60
- self ._init_state_ops = {
148
+ self ._init_state_opts = {
61
149
"binary" : "0" * max (1 , len (self ._wires )),
62
150
"dtype" : self ._dtype .__name__ ,
63
151
"tags" : [str (l ) for l in self ._wires .labels ],
64
152
}
65
153
66
154
self ._gate_opts = {
67
- "contract" : "swap+split" ,
68
155
"parametrize" : None ,
156
+ "contract" : interf_opts ["contract" ],
69
157
"cutoff" : interf_opts ["cutoff" ],
70
158
"max_bond" : interf_opts ["max_bond_dim" ],
71
159
}
@@ -79,22 +167,205 @@ def __init__(self, num_wires, interf_opts, dtype=np.complex128):
79
167
self ._circuitMPS = qtn .CircuitMPS (psi0 = self ._initial_mps ())
80
168
81
169
@property
82
- def state (self ):
83
- """Current MPS handled by the interface."""
170
+ def interface_name (self ) -> str :
171
+ """The name of this interface."""
172
+ return "QuimbMPS interface"
173
+
174
+ @property
175
+ def state (self ) -> qtn .MatrixProductState :
176
+ """Return the current MPS handled by the interface."""
84
177
return self ._circuitMPS .psi
85
178
86
179
def state_to_array (self ) -> np .ndarray :
87
180
"""Contract the MPS into a dense array."""
88
181
return self ._circuitMPS .to_dense ()
89
182
183
+ def _reset_state (self ) -> None :
184
+ """Reset the MPS."""
185
+ self ._circuitMPS = qtn .CircuitMPS (psi0 = self ._initial_mps ())
186
+
90
187
def _initial_mps (self ) -> qtn .MatrixProductState :
91
188
r"""
92
- Returns an initial state to :math:`\ket{0}`.
189
+ Return an initial state to :math:`\ket{0}`.
93
190
94
191
Internally, it uses `quimb`'s `MPS_computational_state` method.
95
192
96
193
Returns:
97
194
MatrixProductState: The initial MPS of a circuit.
98
195
"""
196
+ return qtn .MPS_computational_state (** self ._init_state_opts )
197
+
198
+ def preprocess (self ) -> TransformProgram :
199
+ """This function defines the device transform program to be applied for this interface.
200
+
201
+ Returns:
202
+ TransformProgram: A transform program that when called returns :class:`~.QuantumTape`'s that the
203
+ device can natively execute as well as a postprocessing function to be called after execution.
204
+
205
+ This interface:
206
+
207
+ * Supports any one or two-qubit operations that provide a matrix.
208
+ * Supports any three or four-qubit operations that provide a decomposition method.
209
+ * Currently does not support finite shots.
210
+ """
211
+
212
+ program = TransformProgram ()
213
+
214
+ program .add_transform (validate_measurements , name = self .interface_name )
215
+ program .add_transform (validate_observables , accepted_observables , name = self .interface_name )
216
+ program .add_transform (validate_device_wires , self ._wires , name = self .interface_name )
217
+ program .add_transform (
218
+ decompose ,
219
+ stopping_condition = stopping_condition ,
220
+ skip_initial_state_prep = True ,
221
+ name = self .interface_name ,
222
+ )
223
+ program .add_transform (qml .transforms .broadcast_expand )
224
+
225
+ return program
226
+
227
+ # pylint: disable=unused-argument
228
+ def execute (
229
+ self ,
230
+ circuits : QuantumTape_or_Batch ,
231
+ execution_config : ExecutionConfig = DefaultExecutionConfig ,
232
+ ) -> Result_or_ResultBatch :
233
+ """Execute a circuit or a batch of circuits and turn it into results.
234
+
235
+ Args:
236
+ circuits (Union[QuantumTape, Sequence[QuantumTape]]): the quantum circuits to be executed
237
+ execution_config (ExecutionConfig): a datastructure with additional information required for execution
238
+
239
+ Returns:
240
+ TensorLike, tuple[TensorLike], tuple[tuple[TensorLike]]: A numeric result of the computation.
241
+ """
242
+
243
+ results = []
244
+ for circuit in circuits :
245
+ circuit = circuit .map_to_standard_wires ()
246
+ results .append (self .simulate (circuit ))
247
+
248
+ return tuple (results )
249
+
250
+ def simulate (self , circuit : QuantumScript ) -> Result :
251
+ """Simulate a single quantum script. This function assumes that all operations provide matrices.
252
+
253
+ Args:
254
+ circuit (QuantumScript): The single circuit to simulate.
255
+
256
+ Returns:
257
+ Tuple[TensorLike]: The results of the simulation.
258
+ """
259
+
260
+ self ._reset_state ()
261
+
262
+ for op in circuit .operations :
263
+ self ._apply_operation (op )
264
+
265
+ if not circuit .shots :
266
+ if len (circuit .measurements ) == 1 :
267
+ return self .measurement (circuit .measurements [0 ])
268
+ return tuple (self .measurement (mp ) for mp in circuit .measurements )
269
+
270
+ raise NotImplementedError # pragma: no cover
271
+
272
+ def _apply_operation (self , op : qml .operation .Operator ) -> None :
273
+ """Apply a single operator to the circuit, keeping the state always in a MPS form.
274
+
275
+ Internally it uses `quimb`'s `apply_gate` method.
276
+
277
+ Args:
278
+ op (Operator): The operation to apply.
279
+ """
280
+
281
+ self ._circuitMPS .apply_gate (op .matrix ().astype (self ._dtype ), * op .wires , ** self ._gate_opts )
282
+
283
+ def measurement (self , measurementprocess : MeasurementProcess ) -> TensorLike :
284
+ """Measure the measurement required by the circuit over the MPS.
285
+
286
+ Args:
287
+ measurementprocess (MeasurementProcess): measurement to apply to the state.
288
+
289
+ Returns:
290
+ TensorLike: the result of the measurement.
291
+ """
292
+
293
+ return self ._get_measurement_function (measurementprocess )(measurementprocess )
294
+
295
+ def _get_measurement_function (
296
+ self , measurementprocess : MeasurementProcess
297
+ ) -> Callable [[MeasurementProcess , TensorLike ], TensorLike ]:
298
+ """Get the appropriate method for performing a measurement.
299
+
300
+ Args:
301
+ measurementprocess (MeasurementProcess): measurement process to apply to the state
302
+
303
+ Returns:
304
+ Callable: function that returns the measurement result
305
+ """
306
+ if isinstance (measurementprocess , StateMeasurement ):
307
+ if isinstance (measurementprocess , ExpectationMP ):
308
+ return self .expval
309
+
310
+ if isinstance (measurementprocess , VarianceMP ):
311
+ return self .var
312
+
313
+ raise NotImplementedError # pragma: no cover
314
+
315
+ def expval (self , measurementprocess : MeasurementProcess ) -> float :
316
+ """Expectation value of the supplied observable contained in the MeasurementProcess.
317
+
318
+ Args:
319
+ measurementprocess (StateMeasurement): measurement to apply to the MPS.
320
+
321
+ Returns:
322
+ Expectation value of the observable.
323
+ """
324
+
325
+ obs = measurementprocess .obs
326
+
327
+ result = self ._local_expectation (obs .matrix (), tuple (obs .wires ))
328
+
329
+ return result
330
+
331
+ def var (self , measurementprocess : MeasurementProcess ) -> float :
332
+ """Variance of the supplied observable contained in the MeasurementProcess.
333
+
334
+ Args:
335
+ measurementprocess (StateMeasurement): measurement to apply to the MPS.
336
+
337
+ Returns:
338
+ Variance of the observable.
339
+ """
340
+
341
+ obs = measurementprocess .obs
342
+
343
+ obs_mat = obs .matrix ()
344
+ expect_op = self .expval (measurementprocess )
345
+ expect_squar_op = self ._local_expectation (obs_mat @ obs_mat .conj ().T , tuple (obs .wires ))
346
+
347
+ return expect_squar_op - np .square (expect_op )
348
+
349
+ def _local_expectation (self , matrix , wires ) -> float :
350
+ """Compute the local expectation value of a matrix on the MPS.
351
+
352
+ Internally, it uses `quimb`'s `local_expectation` method.
353
+
354
+ Args:
355
+ matrix (array): the matrix to compute the expectation value of.
356
+ wires (tuple[int]): the wires the matrix acts on.
357
+
358
+ Returns:
359
+ Local expectation value of the matrix on the MPS.
360
+ """
361
+
362
+ # We need to copy the MPS to avoid modifying the original state
363
+ qc = copy .deepcopy (self ._circuitMPS )
364
+
365
+ exp_val = qc .local_expectation (
366
+ matrix ,
367
+ wires ,
368
+ ** self ._expval_opts ,
369
+ )
99
370
100
- return qtn . MPS_computational_state ( ** self . _init_state_ops )
371
+ return float ( np . real ( exp_val ) )
0 commit comments