Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Defer op-arithmetic to default qubit #349

Merged
merged 5 commits into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@
* Implements caching for Kokkos installation.
[(#316)](https://github.com/PennyLaneAI/pennylane-lightning/pull/316)

* Supports measurements of operator arithmetic classes such as `Sum`, `Prod`,
and `SProd` by deferring handling of them to `DefaultQubit`.
[(#349)](https://github.com/PennyLaneAI/pennylane-lightning/pull/349)

```
@qml.qnode(qml.device('lightning.qubit', wires=2))
def circuit():
obs = qml.s_prod(2.1, qml.PauliZ(0)) + qml.op_sum(qml.PauliX(0), qml.PauliZ(1))
return qml.expval(obs)
```

### Documentation

### Bug fixes
Expand All @@ -32,7 +43,7 @@

This release contains contributions from (in alphabetical order):

Amintor Dusko, Chae-Yeun Park
Amintor Dusko, Christina Lee, Chae-Yeun Park

---

Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.26.0-dev11"
__version__ = "0.26.0-dev12"
22 changes: 9 additions & 13 deletions pennylane_lightning/lightning_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,6 @@ def _remove_snapshot_from_operations(operations):
return operations


def _remove_op_arithmetic_from_observables(observables):
observables = observables.copy()
observables.discard("Sum")
observables.discard("SProd")
observables.discard("Prod")
return observables


class LightningQubit(DefaultQubit):
"""PennyLane Lightning device.

Expand Down Expand Up @@ -111,7 +103,6 @@ class LightningQubit(DefaultQubit):
author = "Xanadu Inc."
_CPP_BINARY_AVAILABLE = True
operations = _remove_snapshot_from_operations(DefaultQubit.operations)
observables = _remove_op_arithmetic_from_observables(DefaultQubit.observables)

def __init__(self, wires, *, c_dtype=np.complex128, shots=None, batch_obs=False):
if c_dtype is np.complex64:
Expand Down Expand Up @@ -617,10 +608,15 @@ def expval(self, observable, shot_range=None, bin_size=None):
Returns:
Expectation value of the observable
"""
if isinstance(observable.name, List) or observable.name in [
"Identity",
"Projector",
]:
if (
(observable.arithmetic_depth > 0)
or isinstance(observable.name, List)
or observable.name
in [
"Identity",
"Projector",
]
):
return super().expval(observable, shot_range=shot_range, bin_size=bin_size)

if self.shots is not None:
Expand Down
17 changes: 0 additions & 17 deletions tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,3 @@ def test_create_device_with_dtype(C):
def test_create_device_with_unsupported_dtype():
with pytest.raises(TypeError, match="Unsupported complex Type:"):
dev = qml.device("lightning.qubit", wires=1, c_dtype=np.complex256)


def test_no_op_arithmetic_support():
"""Test that lightning qubit explicitly does not support SProd, Prod, and Sum."""

dev = qml.device("lightning.qubit", wires=2)
for name in {"Prod", "SProd", "Sum"}:
assert name not in dev.operations

obs = qml.prod(qml.PauliX(0), qml.PauliY(1))

@qml.qnode(dev)
def circuit():
return qml.expval(obs)

with pytest.raises(qml.DeviceError, match=r"Observable Prod not supported on device"):
circuit()
66 changes: 66 additions & 0 deletions tests/test_expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,72 @@ def test_hadamard_expectation(self, theta, phi, qubit_device_3_wires, tol):
assert np.allclose(res, expected, tol)


class TestExpOperatorArithmetic:
"""Test integration of lightning with SProd, Prod, and Sum."""

dev = qml.device("lightning.qubit", wires=2)

def test_sprod(self):
"""Test the `SProd` class with lightning qubit."""

@qml.qnode(self.dev)
def circuit(x):
qml.RX(x, wires=0)
return qml.expval(qml.s_prod(0.5, qml.PauliZ(0)))

x = np.array(0.123)
res = circuit(x)
assert qml.math.allclose(res, 0.5 * np.cos(x))

def test_prod(self):
"""Test the `Prod` class with lightning qubit."""

@qml.qnode(self.dev)
def circuit(x):
qml.RX(x, wires=0)
qml.Hadamard(1)
qml.PauliZ(1)
return qml.expval(qml.prod(qml.PauliZ(0), qml.PauliX(1)))

x = np.array(0.123)
res = circuit(x)
assert qml.math.allclose(res, -np.cos(x))

def test_sum(self):
"""Test the `Sum` class with lightning qubit."""

@qml.qnode(self.dev)
def circuit(x, y):
qml.RX(x, wires=0)
qml.RY(y, wires=1)
return qml.expval(qml.op_sum(qml.PauliZ(0), qml.PauliX(1)))

x = np.array(-3.21)
y = np.array(2.34)
res = circuit(x, y)
assert qml.math.allclose(res, np.cos(x) + np.sin(y))

def test_integration(self):
"""Test a Combination of `Sum`, `SProd`, and `Prod`."""

obs = qml.op_sum(
qml.s_prod(2.3, qml.PauliZ(0)), -0.5 * qml.prod(qml.PauliY(0), qml.PauliZ(1))
)

@qml.qnode(self.dev)
def circuit(x, y):
qml.RX(x, wires=0)
qml.RY(y, wires=1)
return qml.expval(obs)

x = np.array(0.654)
y = np.array(-0.634)

res = circuit(x, y)
expected = 2.3 * np.cos(x) + 0.5 * np.sin(x) * np.cos(y)
assert qml.math.allclose(res, expected)


@pytest.mark.parametrize("theta,phi,varphi", list(zip(THETA, PHI, VARPHI)))
class TestTensorExpval:
"""Test tensor expectation values"""
Expand Down