|
31 | 31 | adjoint_measurements,
|
32 | 32 | adjoint_observables,
|
33 | 33 | decompose,
|
| 34 | + mid_circuit_measurements, |
34 | 35 | no_sampling,
|
35 | 36 | stopping_condition,
|
36 | 37 | stopping_condition_shots,
|
@@ -258,13 +259,12 @@ def test_preprocess(self, adjoint):
|
258 | 259 | expected_program.add_transform(validate_measurements, name=device.name)
|
259 | 260 | expected_program.add_transform(validate_observables, accepted_observables, name=device.name)
|
260 | 261 | expected_program.add_transform(validate_device_wires, device.wires, name=device.name)
|
261 |
| - expected_program.add_transform( |
262 |
| - qml.devices.preprocess.mid_circuit_measurements, device=device |
263 |
| - ) |
| 262 | + expected_program.add_transform(mid_circuit_measurements, device=device) |
264 | 263 | expected_program.add_transform(
|
265 | 264 | decompose,
|
266 | 265 | stopping_condition=stopping_condition,
|
267 | 266 | stopping_condition_shots=stopping_condition_shots,
|
| 267 | + skip_initial_state_prep=True, |
268 | 268 | name=device.name,
|
269 | 269 | )
|
270 | 270 | expected_program.add_transform(qml.transforms.broadcast_expand)
|
@@ -293,6 +293,63 @@ def test_preprocess(self, adjoint):
|
293 | 293 | actual_program, _ = device.preprocess(config)
|
294 | 294 | assert actual_program == expected_program
|
295 | 295 |
|
| 296 | + @pytest.mark.parametrize( |
| 297 | + "op, is_trainable", |
| 298 | + [ |
| 299 | + (qml.StatePrep([1 / np.sqrt(2), 1 / np.sqrt(2)], wires=0), False), |
| 300 | + (qml.StatePrep(qml.numpy.array([1 / np.sqrt(2), 1 / np.sqrt(2)]), wires=0), True), |
| 301 | + (qml.StatePrep(np.array([1, 0]), wires=0), False), |
| 302 | + (qml.BasisState([1, 1], wires=[0, 1]), False), |
| 303 | + (qml.BasisState(qml.numpy.array([1, 1]), wires=[0, 1]), True), |
| 304 | + ], |
| 305 | + ) |
| 306 | + def test_preprocess_state_prep_first_op_decomposition(self, op, is_trainable): |
| 307 | + """Test that state prep ops in the beginning of a tape are decomposed with adjoint |
| 308 | + but not otherwise.""" |
| 309 | + tape = qml.tape.QuantumScript([op, qml.RX(1.23, wires=0)], [qml.expval(qml.PauliZ(0))]) |
| 310 | + device = LightningDevice(wires=3) |
| 311 | + |
| 312 | + if is_trainable: |
| 313 | + # Need to decompose twice as the state prep ops we use first decompose into a template |
| 314 | + decomp = op.decomposition()[0].decomposition() |
| 315 | + else: |
| 316 | + decomp = [op] |
| 317 | + |
| 318 | + config = ExecutionConfig(gradient_method="best" if is_trainable else None) |
| 319 | + program, _ = device.preprocess(config) |
| 320 | + [new_tape], _ = program([tape]) |
| 321 | + expected_tape = qml.tape.QuantumScript([*decomp, qml.RX(1.23, wires=0)], tape.measurements) |
| 322 | + assert qml.equal(new_tape, expected_tape) |
| 323 | + |
| 324 | + @pytest.mark.parametrize( |
| 325 | + "op, decomp_depth", |
| 326 | + [ |
| 327 | + (qml.StatePrep([1 / np.sqrt(2), 1 / np.sqrt(2)], wires=0), 1), |
| 328 | + (qml.StatePrep(np.array([1, 0]), wires=0), 1), |
| 329 | + (qml.BasisState([1, 1], wires=[0, 1]), 1), |
| 330 | + (qml.BasisState(qml.numpy.array([1, 1]), wires=[0, 1]), 1), |
| 331 | + (qml.AmplitudeEmbedding([1 / np.sqrt(2), 1 / np.sqrt(2)], wires=0), 2), |
| 332 | + (qml.MottonenStatePreparation([1 / np.sqrt(2), 1 / np.sqrt(2)], wires=0), 0), |
| 333 | + ], |
| 334 | + ) |
| 335 | + def test_preprocess_state_prep_middle_op_decomposition(self, op, decomp_depth): |
| 336 | + """Test that state prep ops in the middle of a tape are always decomposed.""" |
| 337 | + tape = qml.tape.QuantumScript( |
| 338 | + [qml.RX(1.23, wires=0), op, qml.CNOT([0, 1])], [qml.expval(qml.PauliZ(0))] |
| 339 | + ) |
| 340 | + device = LightningDevice(wires=3) |
| 341 | + |
| 342 | + for _ in range(decomp_depth): |
| 343 | + op = op.decomposition()[0] |
| 344 | + decomp = op.decomposition() |
| 345 | + |
| 346 | + program, _ = device.preprocess() |
| 347 | + [new_tape], _ = program([tape]) |
| 348 | + expected_tape = qml.tape.QuantumScript( |
| 349 | + [qml.RX(1.23, wires=0), *decomp, qml.CNOT([0, 1])], tape.measurements |
| 350 | + ) |
| 351 | + assert qml.equal(new_tape, expected_tape) |
| 352 | + |
296 | 353 | @pytest.mark.usefixtures("use_legacy_and_new_opmath")
|
297 | 354 | @pytest.mark.parametrize("theta, phi", list(zip(THETA, PHI)))
|
298 | 355 | @pytest.mark.parametrize(
|
|
0 commit comments