From 386985a0008882b73681118b0b7bdc59a8f13366 Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Thu, 16 Jan 2025 13:31:06 -0800 Subject: [PATCH] fix: Flatten observable before getting targets (#287) Fixes #285 --- setup.py | 4 ++-- src/braket/pennylane_plugin/braket_device.py | 3 ++- src/braket/pennylane_plugin/translation.py | 8 ++++---- test/integ_tests/test_apply.py | 8 ++++---- test/unit_tests/test_braket_device.py | 6 ++++-- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/setup.py b/setup.py index f44fa9a7..b586489f 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ install_requires=[ "amazon-braket-sdk>=1.87.0", "autoray>=0.6.11", - "pennylane>=0.34.0,<0.40", + "pennylane>=0.34.0", ], entry_points={ "pennylane.plugins": [ @@ -53,7 +53,7 @@ }, extras_require={ "test": [ - "autoray<0.7.0", # autoray.tensorflow_diag no longer works + "autoray<0.7.0", # autoray.tensorflow_diag no longer works "docutils>=0.19", "flaky", "pre-commit", diff --git a/src/braket/pennylane_plugin/braket_device.py b/src/braket/pennylane_plugin/braket_device.py index 01c183bd..6a3559ad 100644 --- a/src/braket/pennylane_plugin/braket_device.py +++ b/src/braket/pennylane_plugin/braket_device.py @@ -75,6 +75,7 @@ from braket.device_schema import DeviceActionType from braket.devices import Device, LocalSimulator from braket.pennylane_plugin.translation import ( + flatten_observable, get_adjoint_gradient_result_type, supported_observables, supported_operations, @@ -281,7 +282,7 @@ def _apply_gradient_result_type(self, circuit, braket_circuit): f" observable, not {len(circuit.observables)} observables." ) pl_measurements = circuit.measurements[0] - pl_observable = pl_measurements.obs + pl_observable = flatten_observable(pl_measurements.obs) if pl_measurements.return_type != Expectation: raise ValueError( f"Braket can only compute gradients for circuits with a single expectation" diff --git a/src/braket/pennylane_plugin/translation.py b/src/braket/pennylane_plugin/translation.py index 3587ddcb..7a9fcd61 100644 --- a/src/braket/pennylane_plugin/translation.py +++ b/src/braket/pennylane_plugin/translation.py @@ -546,7 +546,7 @@ def get_adjoint_gradient_result_type( if "AdjointGradient" not in supported_result_types: raise NotImplementedError("Unsupported return type: AdjointGradient") - braket_observable = _translate_observable(_flatten_observable(observable)) + braket_observable = _translate_observable(observable) braket_observable = ( braket_observable.item() if hasattr(braket_observable, "item") else braket_observable ) @@ -590,7 +590,7 @@ def translate_result_type( # noqa: C901 return tuple(Sample(observables.Z(target)) for target in targets or measurement.wires) raise NotImplementedError(f"Unsupported return type: {return_type}") - observable = _flatten_observable(observable) + observable = flatten_observable(observable) if isinstance(observable, qml.ops.LinearCombination): if return_type is ObservableReturnTypes.Expectation: @@ -608,7 +608,7 @@ def translate_result_type( # noqa: C901 raise NotImplementedError(f"Unsupported return type: {return_type}") -def _flatten_observable(observable): +def flatten_observable(observable): if isinstance(observable, (qml.ops.CompositeOp, qml.ops.SProd)): simplified = qml.ops.LinearCombination(*observable.terms()).simplify() coeffs, _ = simplified.terms() @@ -735,7 +735,7 @@ def translate_result( return dict(braket_result.measurement_counts) translated = translate_result_type(measurement, targets, supported_result_types) - observable = _flatten_observable(observable) + observable = flatten_observable(observable) if isinstance(observable, qml.ops.LinearCombination): coeffs, _ = observable.terms() return sum( diff --git a/test/integ_tests/test_apply.py b/test/integ_tests/test_apply.py index 4986affc..65b9d5b8 100755 --- a/test/integ_tests/test_apply.py +++ b/test/integ_tests/test_apply.py @@ -85,7 +85,7 @@ def test_qubit_state_vector(self, init_state, device, tol): @qml.qnode(dev) def circuit(): - qml.QubitStateVector.compute_decomposition(state, wires=[0]) + qml.StatePrep.compute_decomposition(state, wires=[0]) return qml.probs(wires=range(1)) assert np.allclose(circuit(), np.abs(state) ** 2, **tol) @@ -177,7 +177,7 @@ def test_qubit_channel(self, init_state, dm_device, kraus, tol): def assert_op_and_inverse(op, dev, state, wires, tol, op_args): @qml.qnode(dev) def circuit(): - qml.QubitStateVector.compute_decomposition(state, wires=wires) + qml.StatePrep.compute_decomposition(state, wires=wires) op(*op_args, wires=wires) return qml.probs(wires=wires) @@ -185,7 +185,7 @@ def circuit(): @qml.qnode(dev) def circuit_inv(): - qml.QubitStateVector.compute_decomposition(state, wires=wires) + qml.StatePrep.compute_decomposition(state, wires=wires) qml.adjoint(op(*op_args, wires=wires)) return qml.probs(wires=wires) @@ -197,7 +197,7 @@ def circuit_inv(): def assert_noise_op(op, dev, state, wires, tol, op_args): @qml.qnode(dev) def circuit(): - qml.QubitStateVector.compute_decomposition(state, wires=wires) + qml.StatePrep.compute_decomposition(state, wires=wires) op(*op_args, wires=wires) return qml.probs(wires=wires) diff --git a/test/unit_tests/test_braket_device.py b/test/unit_tests/test_braket_device.py index 32612370..11c659a8 100644 --- a/test/unit_tests/test_braket_device.py +++ b/test/unit_tests/test_braket_device.py @@ -336,7 +336,10 @@ def test_execute_parametrize_differentiable(mock_run): qml.RY(0.543, wires=0), ], measurements=[ - qml.expval(2 * qml.PauliX(0) @ qml.PauliY(1) + 0.75 * qml.PauliY(0) @ qml.PauliZ(1)), + qml.expval( + 2 * qml.PauliX(0) @ qml.PauliY(1) @ qml.Identity(2) + + 0.75 * qml.PauliY(0) @ qml.PauliZ(1) + ), ], ) CIRCUIT_3.trainable_params = [0, 1] @@ -569,7 +572,6 @@ def test_execute_with_gradient_no_op_math( result_types, expected_pl_result, ): - task = Mock() type(task).id = PropertyMock(return_value="task_arn") task.state.return_value = "COMPLETED"