Skip to content

Commit

Permalink
Add exponential backoff to HTTPQueue put (#18013)
Browse files Browse the repository at this point in the history
  • Loading branch information
lantiga authored Jul 7, 2023
1 parent 2d5964d commit bb47517
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
1 change: 1 addition & 0 deletions requirements/app/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ inquirer >=2.10.0, <=3.1.3
psutil <5.9.5
click <=8.1.3
python-multipart>=0.0.5, <=0.0.6
backoff >=2.2.1, <2.3.0

fastapi >=0.92.0, <0.100.0
starlette # https://fastapi.tiangolo.com/deployment/versions/#about-starlette
Expand Down
2 changes: 2 additions & 0 deletions src/lightning/app/core/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Any, Optional, Tuple
from urllib.parse import urljoin

import backoff
import requests
from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout

Expand Down Expand Up @@ -431,6 +432,7 @@ def _get(self) -> Any:
# we consider the queue is empty to avoid failing the app.
raise queue.Empty

@backoff.on_exception(backoff.expo, (RuntimeError, requests.exceptions.HTTPError))
def put(self, item: Any) -> None:
if not self.app_id:
raise ValueError(f"The Lightning App ID couldn't be extracted from the queue name: {self.name}")
Expand Down
14 changes: 11 additions & 3 deletions tests/tests_app/core/test_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,21 @@ def test_http_queue_get(self, monkeypatch):

def test_unreachable_queue(monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")

test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")

resp = mock.MagicMock()
resp.status_code = 204
resp1 = mock.MagicMock()
resp1.status_code = 204

resp2 = mock.MagicMock()
resp2.status_code = 201

test_queue.client = mock.MagicMock()
test_queue.client.post.return_value = resp
test_queue.client.post = mock.Mock(side_effect=[resp1, resp1, resp2])

with pytest.raises(queue.Empty):
test_queue._get()

# Test backoff on queue.put
test_queue.put("foo")
assert test_queue.client.post.call_count == 3

0 comments on commit bb47517

Please sign in to comment.