Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate BasePrimitiveResult #11054

Merged
merged 1 commit into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions qiskit/primitives/base/base_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,15 @@
from typing import Any, Dict

from numpy import ndarray
from qiskit.utils.deprecation import deprecate_func


ExperimentData = Dict[str, Any]


class BasePrimitiveResult(ABC):
"""Primitive result abstract base class.

Base class for Primitive results meant to provide common functionality to all inheriting
result dataclasses.
class _BasePrimitiveResult(ABC):
"""
Base class for deprecated Primitive result methods.
"""

def __post_init__(self) -> None:
Expand All @@ -45,22 +44,27 @@ def __post_init__(self) -> None:
TypeError: If one of the data fields is not a Sequence or ``numpy.ndarray``.
ValueError: Inconsistent number of experiments across data fields.
"""
num_experiments = None
for value in self._field_values: # type: Sequence
if num_experiments is None:
num_experiments = len(value)
# TODO: enforce all data fields to be tuples instead of sequences
if not isinstance(value, (Sequence, ndarray)) or isinstance(value, (str, bytes)):
raise TypeError(
f"Expected sequence or `numpy.ndarray`, provided {type(value)} instead."
)
if len(value) != self.num_experiments:
if len(value) != num_experiments:
raise ValueError("Inconsistent number of experiments across data fields.")

@property # TODO: functools.cached_property when py37 is droppped
@property # TODO: functools.cached_property when py37 is dropped
@deprecate_func(since="0.46.0", is_property=True)
def num_experiments(self) -> int:
"""Number of experiments in any inheriting result dataclass."""
value: Sequence = self._field_values[0]
return len(value)

@property # TODO: functools.cached_property when py37 is droppped
@property # TODO: functools.cached_property when py37 is dropped
@deprecate_func(since="0.46.0", is_property=True)
def experiments(self) -> tuple[ExperimentData, ...]:
"""Experiment data dicts in any inheriting result dataclass."""
return tuple(self._generate_experiments())
Expand All @@ -71,17 +75,36 @@ def _generate_experiments(self) -> Iterator[ExperimentData]:
for values in zip(*self._field_values):
yield dict(zip(names, values))

def decompose(self) -> Iterator[BasePrimitiveResult]:
@deprecate_func(since="0.46.0")
def decompose(self) -> Iterator[_BasePrimitiveResult]:
"""Generate single experiment result objects from self."""
for values in zip(*self._field_values):
yield self.__class__(*[(v,) for v in values])

@property # TODO: functools.cached_property when py37 is droppped
@property # TODO: functools.cached_property when py37 is dropped
def _field_names(self) -> tuple[str, ...]:
"""Tuple of field names in any inheriting result dataclass."""
return tuple(field.name for field in fields(self))

@property # TODO: functools.cached_property when py37 is droppped
@property # TODO: functools.cached_property when py37 is dropped
def _field_values(self) -> tuple:
"""Tuple of field values in any inheriting result dataclass."""
return tuple(getattr(self, name) for name in self._field_names)


# Deprecation warning for importing BasePrimitiveResult directly


def __getattr__(name):
if name == "BasePrimitiveResult":
import warnings

warnings.warn(
"The BasePrimitiveResult class is deprecated since Qiskit 0.46"
" and will be removed in Qiskit 1.0. Use EstimatorResult or SamplerResult"
" directly instead",
DeprecationWarning,
stacklevel=2,
)
return _BasePrimitiveResult
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
4 changes: 2 additions & 2 deletions qiskit/primitives/base/estimator_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from .base_result import BasePrimitiveResult
from .base_result import _BasePrimitiveResult

if TYPE_CHECKING:
import numpy as np


@dataclass(frozen=True)
class EstimatorResult(BasePrimitiveResult):
class EstimatorResult(_BasePrimitiveResult):
"""Result of Estimator.

.. code-block:: python
Expand Down
4 changes: 2 additions & 2 deletions qiskit/primitives/base/sampler_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@

from qiskit.result import QuasiDistribution

from .base_result import BasePrimitiveResult
from .base_result import _BasePrimitiveResult


@dataclass(frozen=True)
class SamplerResult(BasePrimitiveResult):
class SamplerResult(_BasePrimitiveResult):
"""Result of Sampler.

.. code-block:: python
Expand Down
6 changes: 3 additions & 3 deletions qiskit/primitives/primitive_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@

import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Generic, TypeVar
from typing import Generic, TypeVar, Union

from qiskit.providers import JobError, JobStatus, JobV1

from .base.base_result import BasePrimitiveResult
from .base import EstimatorResult, SamplerResult

T = TypeVar("T", bound=BasePrimitiveResult)
T = TypeVar("T", bound=Union[SamplerResult, EstimatorResult])


class PrimitiveJob(JobV1, Generic[T]):
Expand Down
23 changes: 13 additions & 10 deletions test/python/primitives/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from ddt import data, ddt, unpack

from qiskit.primitives.base.base_result import BasePrimitiveResult
from qiskit.primitives.base.base_result import _BasePrimitiveResult as BasePrimitiveResult
from qiskit.test import QiskitTestCase


Expand Down Expand Up @@ -57,28 +57,31 @@ def test_post_init_value_error(self, field_1, field_2):
def test_num_experiments(self, num_experiments):
"""Tests {num_experiments} num_experiments."""
result = Result([0] * num_experiments, [1] * num_experiments)
self.assertEqual(num_experiments, result.num_experiments)
with self.assertRaises(DeprecationWarning):
self.assertEqual(num_experiments, result.num_experiments)

@data(0, 1, 2, 3)
def test_experiments(self, num_experiments):
"""Test experiment data."""
field_1 = list(range(num_experiments))
field_2 = [i + 1 for i in range(num_experiments)]
experiments = Result(field_1, field_2).experiments
self.assertIsInstance(experiments, tuple)
for i, exp in enumerate(experiments):
self.assertEqual(exp, {"field_1": i, "field_2": i + 1})
with self.assertRaises(DeprecationWarning):
experiments = Result(field_1, field_2).experiments
self.assertIsInstance(experiments, tuple)
for i, exp in enumerate(experiments):
self.assertEqual(exp, {"field_1": i, "field_2": i + 1})

@data(0, 1, 2, 3)
def test_decompose(self, num_experiments):
"""Test decompose."""
field_1 = list(range(num_experiments))
field_2 = [i + 1 for i in range(num_experiments)]
result = Result(field_1, field_2)
for i, res in enumerate(result.decompose()):
self.assertIsInstance(res, Result)
f1, f2 = (i,), (i + 1,)
self.assertEqual(res, Result(f1, f2))
with self.assertRaises(DeprecationWarning):
for i, res in enumerate(result.decompose()):
self.assertIsInstance(res, Result)
f1, f2 = (i,), (i + 1,)
self.assertEqual(res, Result(f1, f2))

def test_field_names(self):
"""Tests field names ("field_1", "field_2")."""
Expand Down