From bb4751729e68393a5b88b652a846aec03c7a84bd Mon Sep 17 00:00:00 2001
From: Luca Antiga <luca.antiga@gmail.com>
Date: Fri, 7 Jul 2023 11:11:09 +0200
Subject: [PATCH] Add exponential backoff to HTTPQueue put (#18013)

---
 requirements/app/base.txt           |  1 +
 src/lightning/app/core/queues.py    |  2 ++
 tests/tests_app/core/test_queues.py | 14 +++++++++++---
 3 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/requirements/app/base.txt b/requirements/app/base.txt
index e9e1d284e1434..80607f79a91a2 100644
--- a/requirements/app/base.txt
+++ b/requirements/app/base.txt
@@ -13,6 +13,7 @@ inquirer >=2.10.0, <=3.1.3
 psutil <5.9.5
 click <=8.1.3
 python-multipart>=0.0.5, <=0.0.6
+backoff >=2.2.1, <2.3.0
 
 fastapi >=0.92.0, <0.100.0
 starlette  # https://fastapi.tiangolo.com/deployment/versions/#about-starlette
diff --git a/src/lightning/app/core/queues.py b/src/lightning/app/core/queues.py
index 5b27b601fb5e5..18e02dd989b2e 100644
--- a/src/lightning/app/core/queues.py
+++ b/src/lightning/app/core/queues.py
@@ -23,6 +23,7 @@
 from typing import Any, Optional, Tuple
 from urllib.parse import urljoin
 
+import backoff
 import requests
 from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
 
@@ -431,6 +432,7 @@ def _get(self) -> Any:
             # we consider the queue is empty to avoid failing the app.
             raise queue.Empty
 
+    @backoff.on_exception(backoff.expo, (RuntimeError, requests.exceptions.HTTPError))
     def put(self, item: Any) -> None:
         if not self.app_id:
             raise ValueError(f"The Lightning App ID couldn't be extracted from the queue name: {self.name}")
diff --git a/tests/tests_app/core/test_queues.py b/tests/tests_app/core/test_queues.py
index d00432b734aad..8dd6d7d3a0b32 100644
--- a/tests/tests_app/core/test_queues.py
+++ b/tests/tests_app/core/test_queues.py
@@ -217,13 +217,21 @@ def test_http_queue_get(self, monkeypatch):
 
 def test_unreachable_queue(monkeypatch):
     monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
+
     test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
 
-    resp = mock.MagicMock()
-    resp.status_code = 204
+    resp1 = mock.MagicMock()
+    resp1.status_code = 204
+
+    resp2 = mock.MagicMock()
+    resp2.status_code = 201
 
     test_queue.client = mock.MagicMock()
-    test_queue.client.post.return_value = resp
+    test_queue.client.post = mock.Mock(side_effect=[resp1, resp1, resp2])
 
     with pytest.raises(queue.Empty):
         test_queue._get()
+
+    # Test backoff on queue.put
+    test_queue.put("foo")
+    assert test_queue.client.post.call_count == 3