Skip to content

Commit 76ba4d4

Browse files
fixed catalyst to correctly extract number of shots for legacyDeviceFacade (#1035)
**Context:** This PR is to fix an issue that will be caused after merging the following PR: PennyLaneAI/pennylane#6046 The problem is that LegacyDeviceFacade will cause the Legacy Devices to use the Shots class instead of an integer which causes a crash in runtime. **Description of the Change:** This PR ensures that the any device that is an instance of the new device API uses the shots class to extract the total number of shots. **Related GitHub Issues:** [sc-70792] --------- Co-authored-by: Ahmed Darwish <[email protected]>
1 parent 68f579e commit 76ba4d4

10 files changed

+50
-23
lines changed

.dep-versions

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ enzyme=v0.0.130
88

99
# For a custom PL version, update the package version here and at
1010
# 'doc/requirements.txt
11-
pennylane=0.38.0.dev11
11+
pennylane=0.38.0.dev21
1212

1313
# For a custom LQ/LK version, update the package version here and at
1414
# 'doc/requirements.txt'. Also, update the 'LIGHTNING_GIT_TAG' at
1515
# 'runtime/Makefile' and at all GitHub workflows, using the exact
1616
# commit hash corresponding to the merged PR that implements the
1717
# desired feature.
18-
lightning=0.38.0-dev34
18+
lightning=0.38.0-dev50

doc/changelog.md

+3
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,9 @@
456456

457457
<h3>Bug fixes</h3>
458458

459+
* Fix a bug where LegacyDevice number of shots is not correctly extracted when using the legacyDeviceFacade.
460+
[(#1035)](https://github.com/PennyLaneAI/catalyst/pull/1035)
461+
459462
* Catalyst no longer generates a `QubitUnitary` operation during decomposition if a device doesn't
460463
support it. Instead, the operation that would lead to a `QubitUnitary` is either decomposed or
461464
raises an error.

doc/requirements.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ lxml_html_clean
3030

3131
# Pre-install PL development wheels
3232
--extra-index-url https://test.pypi.org/simple/
33-
pennylane-lightning-kokkos==0.38.0-dev34
34-
pennylane-lightning==0.38.0-dev34
35-
pennylane==0.38.0.dev11
33+
pennylane-lightning-kokkos==0.38.0-dev50
34+
pennylane-lightning==0.38.0-dev50
35+
pennylane==0.38.0.dev21

frontend/catalyst/device/qjit_device.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
verify_operations,
4545
)
4646
from catalyst.logging import debug_logger, debug_logger_init
47+
from catalyst.third_party.cuda import SoftwareQQPP
4748
from catalyst.utils.exceptions import CompileError
4849
from catalyst.utils.patching import Patcher
4950
from catalyst.utils.runtime_environment import get_lib_path
@@ -133,7 +134,7 @@
133134

134135
def get_device_shots(dev):
135136
"""Helper function to get device shots."""
136-
return dev.shots if isinstance(dev, qml.devices.LegacyDevice) else dev.shots.total_shots
137+
return dev.shots.total_shots if isinstance(dev, qml.devices.Device) else dev.shots
137138

138139

139140
@dataclass
@@ -153,8 +154,8 @@ def extract_backend_info(device: qml.QubitDevice, capabilities: DeviceCapabiliti
153154
to a valid TOML config file."""
154155

155156
dname = device.name
156-
if isinstance(device, qml.devices.LegacyDevice):
157-
dname = device.short_name
157+
if isinstance(device, qml.devices.LegacyDeviceFacade):
158+
dname = device.target_device.short_name
158159

159160
device_name = ""
160161
device_lpath = ""
@@ -189,18 +190,18 @@ def extract_backend_info(device: qml.QubitDevice, capabilities: DeviceCapabiliti
189190
device_kwargs["device_type"] = dname
190191
device_kwargs["backend"] = (
191192
# pylint: disable=protected-access
192-
device._device._delegate.DEVICE_ID
193+
device.target_device._device._delegate.DEVICE_ID
193194
)
194195
elif dname == "braket.aws.qubit": # pragma: no cover
195196
device_kwargs["device_type"] = dname
196197
device_kwargs["device_arn"] = device._device._arn # pylint: disable=protected-access
197-
if device._s3_folder: # pylint: disable=protected-access
198+
if device.target_device._s3_folder: # pylint: disable=protected-access
198199
device_kwargs["s3_destination_folder"] = str(
199-
device._s3_folder # pylint: disable=protected-access
200+
device.target_device._s3_folder # pylint: disable=protected-access
200201
)
201202

202203
for k, v in capabilities.options.items():
203-
if hasattr(device, v):
204+
if hasattr(device, v) and not k in device_kwargs:
204205
device_kwargs[k] = getattr(device, v)
205206

206207
return BackendInfo(dname, device_name, device_lpath, device_kwargs)
@@ -506,6 +507,8 @@ def preprocess(
506507
def _measurement_transform_program(self):
507508

508509
measurement_program = TransformProgram()
510+
if isinstance(self.original_device, SoftwareQQPP):
511+
return measurement_program
509512

510513
supports_sum_observables = any(
511514
obs in self.qjit_capabilities.native_obs for obs in ("Sum", "Hamiltonian")

frontend/catalyst/qfunc.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,8 @@ def processing_fn(results):
253253
total_shots = get_device_shots(dev)
254254

255255
new_dev = copy(dev)
256-
if isinstance(new_dev, qml.devices.LegacyDevice):
257-
new_dev.shots = 1 # pragma: no cover
256+
if isinstance(new_dev, qml.devices.LegacyDeviceFacade):
257+
new_dev.target_device.shots = 1 # pragma: no cover
258258
else:
259259
new_dev._shots = qml.measurements.Shots(1)
260260
single_shot_qnode.device = new_dev

frontend/catalyst/third_party/cuda/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,10 @@ class BaseCudaInstructionSet(qml.QubitDevice):
127127
"SWAP",
128128
"CSWAP",
129129
]
130-
observables = []
130+
observables = [
131+
"PauliX",
132+
"PauliZ",
133+
]
131134
config = Path(__file__).parent / "cuda_quantum.toml"
132135

133136
def __init__(self, shots=None, wires=None):

frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -825,9 +825,16 @@ def cudaq_backend_info(device, _capabilities) -> BackendInfo:
825825
catalyst-specific. We need to make this API a bit nicer for third-party compilers.
826826
"""
827827
device_name = (
828-
device.short_name if isinstance(device, qml.devices.LegacyDevice) else device.name
828+
device.target_device.short_name
829+
if isinstance(device, qml.devices.LegacyDeviceFacade)
830+
else device.name
829831
)
830-
return BackendInfo(device_name, device.name, "", {})
832+
interface_name = (
833+
device.target_device.name
834+
if isinstance(device, qml.devices.LegacyDeviceFacade)
835+
else device.name
836+
)
837+
return BackendInfo(device_name, interface_name, "", {})
831838

832839
with Patcher(
833840
(QFunc, "extract_backend_info", cudaq_backend_info),

frontend/catalyst/third_party/cuda/cuda_quantum.toml

+8
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,17 @@ CSWAP = { properties = [ "invertible" ] }
3434

3535
# Observables supported natively by the device
3636
[operators.observables]
37+
PauliX = {}
38+
PauliZ = {}
39+
Sum = {}
3740

3841
[measurement_processes]
3942

43+
Expval = {}
44+
State = { condition = [ "analytic" ] }
45+
Sample = { condition = [ "finiteshots" ] }
46+
Counts = { condition = [ "finiteshots" ] }
47+
4048
[compilation]
4149
# If the device is compatible with qjit
4250
qjit_compatible = true

frontend/test/pytest/test_cuda_integration.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_measurement_side_effect(self):
6868
from catalyst.third_party.cuda import cudaqjit as cjit
6969

7070
@cjit
71-
@qml.qnode(qml.device("softwareq.qpp", wires=1, shots=30))
71+
@qml.qnode(qml.device("softwareq.qpp", wires=1, shots=None))
7272
def circuit():
7373
qml.RX(jnp.pi / 4, wires=[0])
7474
measure(0)

frontend/test/pytest/test_measurements_shots_results.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,13 @@ def circuit(theta, phi, varphi):
428428
qml.CNOT(wires=[1, 2])
429429
return qml.var(0.2 * qml.PauliZ(wires=0) + 0.5 * qml.Hadamard(wires=1))
430430

431-
if isinstance(dev, qml.devices.Device):
431+
if isinstance(dev, qml.devices.LegacyDeviceFacade):
432+
with pytest.raises(
433+
RuntimeError,
434+
match=r"Cannot split up terms in sums for MeasurementProcess <class 'pennylane.measurements.var.VarianceMP'>",
435+
):
436+
circuit(0.432, 0.123, -0.543)
437+
else:
432438
# TODO: only raises with the new API, Kokkos should also raise an error.
433439
with pytest.raises(
434440
TypeError,
@@ -488,10 +494,7 @@ def test_missing_shots_value(self, backend, meas_fun):
488494
def circuit():
489495
return meas_fun(wires=0)
490496

491-
# ValueError is legacy behaviour with the old device API
492-
error_type = ValueError if isinstance(dev, qml.devices.LegacyDevice) else CompileError
493-
494-
with pytest.raises(error_type, match="cannot work with shots=None"):
497+
with pytest.raises(CompileError, match="cannot work with shots=None"):
495498
qjit(circuit)
496499

497500
def test_multiple_return_values(self, backend, tol_stochastic):

0 commit comments

Comments
 (0)