Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace LightningClient with import from lightning_cloud #18544

Merged
merged 7 commits into from
Sep 13, 2023
2 changes: 1 addition & 1 deletion requirements/app/app.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
lightning-cloud >=0.5.37
lightning-cloud >=0.5.38
packaging
typing-extensions >=4.0.0, <4.8.0
deepdiff >=5.7.0, <6.3.2
Expand Down
77 changes: 3 additions & 74 deletions src/lightning/app/utilities/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
# limitations under the License.

import socket
import time
from functools import wraps
from typing import Any, Callable, Dict, Optional
from urllib.parse import urljoin

import lightning_cloud
import requests
import urllib3
from lightning_cloud.rest_client import create_swagger_client, GridRestClient

# for backwards compatibility
from lightning_cloud.rest_client import create_swagger_client, GridRestClient, LightningClient # noqa: F401
from requests import Session
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
Expand Down Expand Up @@ -87,7 +86,6 @@ def _find_free_network_port_cloudspace():

# Retry budget and backoff factor for connection retries.
# NOTE(review): their consumers are not visible in this excerpt -- presumably the retry
# configuration of the HTTP adapter/session defined further down; confirm before changing.
_CONNECTION_RETRY_TOTAL = 2880
_CONNECTION_RETRY_BACKOFF_FACTOR = 0.5
# Upper bound applied to the exponential backoff delay computed by `_get_next_backoff_time`.
_DEFAULT_BACKOFF_MAX = 5 * 60  # seconds
# Default per-request timeout picked up by `CustomRetryAdapter` when no `timeout` kwarg is given.
_DEFAULT_REQUEST_TIMEOUT = 30  # seconds


Expand Down Expand Up @@ -119,75 +117,6 @@ def _check_service_url_is_ready(url: str, timeout: float = 5, metadata="") -> bo
return False


def _get_next_backoff_time(num_retries: int, backoff_value: float = 0.5) -> float:
    """Return the delay, in seconds, to wait before the next retry attempt.

    The delay doubles with every retry (``backoff_value * 2 ** (num_retries - 1)``) and is capped at
    ``_DEFAULT_BACKOFF_MAX``.

    Args:
        num_retries: Number of consecutive failures so far (``1`` for the first retry).
        backoff_value: Delay used for the first retry; expected to be positive.

    Returns:
        The capped backoff delay in seconds.

    """
    try:
        next_backoff_value = backoff_value * (2 ** (num_retries - 1))
    except OverflowError:
        # For very large retry counts the power of two no longer fits in a float. The uncapped delay is
        # far beyond the maximum by then, so saturate at the cap instead of crashing the retry loop.
        return _DEFAULT_BACKOFF_MAX
    return min(_DEFAULT_BACKOFF_MAX, next_backoff_value)


def _retry_wrapper(self, func: Callable, max_tries: Optional[int] = None) -> Callable:
    """Returns the function decorated by a wrapper that retries the call several times if a connection error occurs.

    The retries follow an exponential backoff.

    Args:
        self: The object passed as first positional argument to ``func`` on every call.
        func: The unbound callable to wrap; it is invoked as ``func(self, *args, **kwargs)``.
        max_tries: Number of consecutive failures after which the wrapper gives up. ``None`` retries
            indefinitely.

    Returns:
        A callable with the metadata of ``func`` that retries on retryable errors.

    Raises:
        Exception: From the returned wrapper, when ``max_tries`` consecutive retryable failures occurred.
            The last underlying error is attached as ``__cause__``.

    """

    @wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        consecutive_errors = 0

        while True:
            try:
                return func(self, *args, **kwargs)
            except (lightning_cloud.openapi.rest.ApiException, urllib3.exceptions.HTTPError) as ex:
                is_transport_error = isinstance(ex, urllib3.exceptions.HTTPError)

                # Retry on transport-level errors and on every API status except 4xx. Among the 4xx
                # statuses only 408 (Request Timeout) and 409 are retried; the others are re-raised
                # immediately with their original traceback.
                if not is_transport_error and ex.status not in (408, 409) and str(ex.status).startswith("4"):
                    raise

                consecutive_errors += 1
                msg = f"error: {str(ex)}" if is_transport_error else f"response: {ex.status}"

                # Give up before logging/sleeping so that no "Retrying" message is emitted for an
                # attempt that will never happen, and keep the root cause attached to the new error.
                if max_tries is not None and consecutive_errors == max_tries:
                    raise Exception(f"The {func.__name__} request failed to reach the server, {msg}.") from ex

                backoff_time = _get_next_backoff_time(consecutive_errors)
                logger.debug(
                    f"The {func.__name__} request failed to reach the server, {msg}."
                    f" Retrying after {backoff_time} seconds."
                )
                time.sleep(backoff_time)

    return wrapped


class LightningClient(GridRestClient):
    """The LightningClient is a wrapper around the GridRestClient.

    It wraps all methods to monitor connection exceptions and employs a retry strategy.

    Args:
        retry: Whether API calls should follow a retry mechanism with exponential backoff.
        max_tries: Number of consecutive failed attempts after which a call gives up. ``None`` (the
            default) retries forever; so does any value below ``1``, because the failure counter
            starts at ``1`` and is compared for equality.

    """

    def __init__(self, retry: bool = True, max_tries: Optional[int] = None) -> None:
        super().__init__(api_client=create_swagger_client())
        if not retry:
            return

        # Walk the MRO from the most generic base up to ``GridRestClient`` itself. When a name is defined
        # at several levels, the wrapper installed last is then the most-derived implementation, which
        # matches normal attribute resolution instead of letting a base class shadow an override.
        for base_class in reversed(GridRestClient.__mro__):
            for name, attribute in base_class.__dict__.items():
                if callable(attribute) and attribute.__name__ != "__init__":
                    setattr(self, name, _retry_wrapper(self, attribute, max_tries=max_tries))


class CustomRetryAdapter(HTTPAdapter):
def __init__(self, *args: Any, **kwargs: Any):
self.timeout = kwargs.pop("timeout", _DEFAULT_REQUEST_TIMEOUT)
Expand Down
29 changes: 1 addition & 28 deletions tests/tests_app/utilities/test_network.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import re
from unittest import mock

import pytest
from urllib3.exceptions import HTTPError

from lightning.app.core import constants
from lightning.app.utilities.network import _retry_wrapper, find_free_network_port, LightningClient
from lightning.app.utilities.network import find_free_network_port


def test_find_free_network_port():
Expand Down Expand Up @@ -45,28 +43,3 @@ def test_find_free_network_port_cloudspace(_, patch_constants):

# Shouldn't use the APP_SERVER_PORT
assert constants.APP_SERVER_PORT not in ports


def test_lightning_client_retry_enabled():
    """Check that API methods carry the retry wrapper exactly when ``retry`` is enabled."""
    probe = "auth_service_get_user_with_http_info"

    # default: retry=True
    default_client = LightningClient()
    assert hasattr(getattr(default_client, probe), "__wrapped__")

    plain_client = LightningClient(retry=False)
    assert not hasattr(getattr(plain_client, probe), "__wrapped__")

    retrying_client = LightningClient(retry=True)
    assert hasattr(getattr(retrying_client, probe), "__wrapped__")


@mock.patch("time.sleep")
def test_retry_wrapper_max_tries(_):
    """A call that always fails with a transport error is attempted ``max_tries`` times, then raises."""
    client = mock.MagicMock()
    client.test.__name__ = "test"
    client.test.side_effect = HTTPError("failed")

    retrying_call = _retry_wrapper(client, client.test, max_tries=3)

    expected_message = re.escape("The test request failed to reach the server, error: failed")
    with pytest.raises(Exception, match=expected_message):
        retrying_call()

    # One call per attempt; sleeping between attempts is patched out.
    assert client.test.call_count == 3