Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MyPy support #41

Merged
merged 3 commits into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[mypy-mockito.*]
ignore_missing_imports = True
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
Changelog
=========

Version 6.2
===========

* Add mypy support. `#41 <https://github.com/iqm-finland/iqm-client/pull/41>`_

Version 6.1
===========

Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ testing =
jsons==1.6.1
jsonschema==4.4.0
mockito==1.3.0
types-requests == 2.28.9
types-jsonschema == 4.14.0
cicd =
twine >= 3.3.0, < 4.0
wheel >= 0.36.2, < 1.0
39 changes: 23 additions & 16 deletions src/iqm_client/iqm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@
import requests
from pydantic import BaseModel, Field

REQUESTS_TIMEOUT = 10

DEFAULT_TIMEOUT_SECONDS = 900
SECONDS_BETWEEN_CALLS = 1
REFRESH_MARGIN_SECONDS = 5
Expand Down Expand Up @@ -498,6 +500,7 @@ def __del__(self):
except Exception: # pylint: disable=broad-except
pass

# pylint: disable=too-many-locals
def submit_circuits(
self,
circuits: list[Circuit],
Expand All @@ -521,6 +524,7 @@ def submit_circuits(
Returns:
ID for the created task. This ID is needed to query the status and the execution results.
"""
serialized_qubit_mapping: Optional[list[SingleQubitMapping]] = None
if qubit_mapping is not None:
# check if qubit mapping is injective
target_qubits = set(qubit_mapping.values())
Expand All @@ -541,13 +545,13 @@ def submit_circuits(
if diff:
raise ValueError(f'The physical qubits {diff} in the qubit mapping are not defined in settings.')

qubit_mapping = serialize_qubit_mapping(qubit_mapping)
serialized_qubit_mapping = serialize_qubit_mapping(qubit_mapping)

# ``bearer_token`` can be ``None`` if cocos we're connecting does not use authentication
bearer_token = self._get_bearer_token()

data = RunRequest(
qubit_mapping=qubit_mapping,
qubit_mapping=serialized_qubit_mapping,
circuits=circuits,
settings=settings,
calibration_set_id=calibration_set_id,
Expand All @@ -562,6 +566,7 @@ def submit_circuits(
join(self._base_url, 'jobs'),
json=data.dict(exclude_none=True),
headers=headers,
timeout=REQUESTS_TIMEOUT
)

if result.status_code == 401:
Expand All @@ -586,16 +591,17 @@ def get_run(self, job_id: UUID) -> RunResult:
bearer_token = self._get_bearer_token()
result = requests.get(
join(self._base_url, 'jobs/', str(job_id)),
headers=None if not bearer_token else {'Authorization': bearer_token}
headers=None if not bearer_token else {'Authorization': bearer_token},
timeout=REQUESTS_TIMEOUT
)
result.raise_for_status()
result = RunResult.from_dict(result.json())
if result.warnings:
for warning in result.warnings:
run_result = RunResult.from_dict(result.json())
if run_result.warnings:
for warning in run_result.warnings:
warnings.warn(warning)
if result.status == Status.FAILED:
raise CircuitExecutionError(result.message)
return result
if run_result.status == Status.FAILED:
raise CircuitExecutionError(run_result.message)
return run_result

def get_run_status(self, job_id: UUID) -> RunStatus:
"""Query the status of the running task.
Expand All @@ -613,14 +619,15 @@ def get_run_status(self, job_id: UUID) -> RunStatus:
bearer_token = self._get_bearer_token()
result = requests.get(
join(self._base_url, 'jobs/', str(job_id), 'status'),
headers=None if not bearer_token else {'Authorization': bearer_token}
headers=None if not bearer_token else {'Authorization': bearer_token},
timeout=REQUESTS_TIMEOUT
)
result.raise_for_status()
result = RunStatus.from_dict(result.json())
if result.warnings:
for warning in result.warnings:
run_result = RunStatus.from_dict(result.json())
if run_result.warnings:
for warning in run_result.warnings:
warnings.warn(warning)
return result
return run_result

def wait_for_results(self, job_id: UUID, timeout_secs: float = DEFAULT_TIMEOUT_SECONDS) -> RunResult:
"""Poll results until run is ready, failed, or timed out.
Expand Down Expand Up @@ -667,7 +674,7 @@ def close_auth_session(self) -> bool:

url = f'{self._credentials.auth_server_url}/realms/{AUTH_REALM}/protocol/openid-connect/logout'
data = AuthRequest(client_id=AUTH_CLIENT_ID, refresh_token=self._credentials.refresh_token)
result = requests.post(url, data=data.dict(exclude_none=True))
result = requests.post(url, data=data.dict(exclude_none=True), timeout=REQUESTS_TIMEOUT)
if result.status_code not in [200, 204]:
raise ClientAuthenticationError(f'Logout failed, {result.text}')
self._credentials.access_token = None
Expand Down Expand Up @@ -728,7 +735,7 @@ def _update_tokens(self):
)

url = f'{self._credentials.auth_server_url}/realms/{AUTH_REALM}/protocol/openid-connect/token'
result = requests.post(url, data=data.dict(exclude_none=True))
result = requests.post(url, data=data.dict(exclude_none=True), timeout=REQUESTS_TIMEOUT)
if result.status_code != 200:
raise ClientAuthenticationError(f'Failed to update tokens, {result.text}')
tokens = result.json()
Expand Down
Empty file added src/iqm_client/py.typed
Empty file.
10 changes: 7 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

from iqm_client import AUTH_CLIENT_ID, AUTH_REALM, AuthRequest, GrantType

REQUESTS_TIMEOUT = 10

existing_run = UUID('3c3fcda3-e860-46bf-92a4-bcc59fa76ce9')
missing_run = UUID('059e4186-50a3-4e6c-ba1f-37fe6afbdfc2')

Expand Down Expand Up @@ -213,7 +215,8 @@ def prepare_tokens(
}
when(requests).post(
f'{credentials["auth_server_url"]}/realms/{AUTH_REALM}/protocol/openid-connect/token',
data=request_data.dict(exclude_none=True)
data=request_data.dict(exclude_none=True),
timeout=REQUESTS_TIMEOUT
).thenReturn(MockJsonResponse(status_code, tokens))

return tokens
Expand Down Expand Up @@ -249,7 +252,7 @@ def expect_status_request(url: str, access_token: Optional[str], times: int = 1)
"""
job_id = uuid4()
headers = None if access_token is None else {'Authorization': f'Bearer {access_token}'}
expect(requests, times=times).get(f'{url}/jobs/{job_id}', headers=headers).thenReturn(
expect(requests, times=times).get(f'{url}/jobs/{job_id}', headers=headers, timeout=REQUESTS_TIMEOUT).thenReturn(
MockJsonResponse(200, {'status': 'pending', 'metadata': {'shots': 42, 'circuits': []}})
)
return job_id
Expand All @@ -265,7 +268,8 @@ def expect_logout(auth_server_url: str, refresh_token: str):
request_data = AuthRequest(client_id=AUTH_CLIENT_ID, refresh_token=refresh_token)
expect(requests, times=1).post(
f'{auth_server_url}/realms/{AUTH_REALM}/protocol/openid-connect/logout',
data=request_data.dict(exclude_none=True)
data=request_data.dict(exclude_none=True),
timeout=REQUESTS_TIMEOUT
).thenReturn(
mock({'status_code': 204, 'text': '{}'})
)
4 changes: 3 additions & 1 deletion tests/test_iqm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
SingleQubitMapping, Status, serialize_qubit_mapping)
from tests.conftest import MockJsonResponse, existing_run, missing_run

REQUESTS_TIMEOUT = 10


def test_serialize_qubit_mapping():
qubit_mapping = {'Alice': 'QB1', 'Bob': 'qubit_3', 'Charlie': 'physical 0'}
Expand Down Expand Up @@ -170,7 +172,7 @@ def test_waiting_for_results(mock_server, base_url, settings_dict):
def test_user_warning_is_emitted_when_warnings_in_response(base_url, settings_dict, capsys):
client = IQMClient(base_url)
msg = 'This is a warning msg'
with when(requests).get(f'{base_url}/jobs/{existing_run}', headers=None).thenReturn(
with when(requests).get(f'{base_url}/jobs/{existing_run}', headers=None, timeout=REQUESTS_TIMEOUT).thenReturn(
MockJsonResponse(200, {'status': 'ready', 'warnings': [msg], 'metadata': {'shots': 42, 'circuits': []}})
):
with pytest.warns(UserWarning, match=msg):
Expand Down
4 changes: 2 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ extras =
testing
commands =
pytest tests --verbose --cov --cov-report term-missing --junitxml=test_report.xml --doctest-modules src
pytest --pylint src/
pytest --pylint tests/ --pylint-rcfile=tests/.pylintrc
pytest --mypy --pylint src/
pytest --mypy --pylint tests/ --pylint-rcfile=tests/.pylintrc
pytest --isort tests/ src/ --verbose

[testenv:docs]
Expand Down