From ac3f1ee0d32b11feebb5a051d64ed8eea99484ec Mon Sep 17 00:00:00 2001
From: awaelchli <aedu.waelchli@gmail.com>
Date: Wed, 22 May 2024 19:19:39 +0200
Subject: [PATCH] Patch release 2.2.5 (#19893)

Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
---
 docs/source-app/conf.py                       |  3 ++
 docs/source-pytorch/conf.py                   |  3 +-
 requirements/app/app.txt                      |  2 +-
 src/lightning/app/core/app.py                 | 12 ++++---
 src/lightning/app/core/constants.py           |  2 ++
 src/lightning/app/runners/cloud.py            |  4 +--
 src/lightning/app/runners/multiprocess.py     | 21 ++++++------
 src/lightning/app/utilities/network.py        | 10 ++++--
 src/lightning/fabric/CHANGELOG.md             |  7 ++++
 .../fabric/plugins/precision/bitsandbytes.py  | 12 +++----
 src/lightning/pytorch/CHANGELOG.md            |  7 ++++
 src/version.info                              |  2 +-
 tests/tests_app/core/test_lightning_app.py    | 24 +++++++++++++
 tests/tests_app/runners/test_cloud.py         | 10 +++---
 tests/tests_app/utilities/test_network.py     |  8 +++--
 .../plugins/precision/test_bitsandbytes.py    | 34 +++++++++++++++++++
 16 files changed, 125 insertions(+), 36 deletions(-)

diff --git a/docs/source-app/conf.py b/docs/source-app/conf.py
index b3e384f2a05dd..fc9efa840ef4c 100644
--- a/docs/source-app/conf.py
+++ b/docs/source-app/conf.py
@@ -449,3 +449,6 @@ def find_source():
 
 # ignore all links in any CHANGELOG file
 linkcheck_exclude_documents = [r"^(.*\/)*CHANGELOG.*$"]
+
+# ignore the following links (false positive errors during linkcheck)
+linkcheck_ignore = ["https://openai.com/"]
diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py
index fa8eb179ead3b..59780dc2e8d95 100644
--- a/docs/source-pytorch/conf.py
+++ b/docs/source-pytorch/conf.py
@@ -343,8 +343,6 @@ def _load_py_module(name: str, location: str) -> ModuleType:
     "graphcore": ("https://docs.graphcore.ai/en/latest/", None),
     "lightning_habana": ("https://lightning-ai.github.io/lightning-Habana/", None),
     "tensorboardX": ("https://tensorboardx.readthedocs.io/en/stable/", None),
-    # needed for referencing App from lightning scope
-    "lightning.app": ("https://lightning.ai/docs/app/stable/", None),
     # needed for referencing Fabric from lightning scope
     "lightning.fabric": ("https://lightning.ai/docs/fabric/stable/", None),
     # TODO: these are missing objects.inv
@@ -626,4 +624,5 @@ def package_list_from_file(file):
     "https://stackoverflow.com/questions/66640705/how-can-i-install-grpcio-on-an-apple-m1-silicon-laptop",
     "https://github.com/Lightning-AI/lightning/blob/master/examples/pytorch/ipu/mnist_sample.py",
     "https://ngc.nvidia.com/catalog/containers/nvidia:nemo",  # in ecosystem/asr_nlp_tts.rst
+    "https://openai.com/",
 ]
diff --git a/requirements/app/app.txt b/requirements/app/app.txt
index 8cfc6e58301d2..2b4ec3c1517ec 100644
--- a/requirements/app/app.txt
+++ b/requirements/app/app.txt
@@ -1,4 +1,4 @@
-lightning-cloud == 0.5.68  # Must be pinned to ensure compatibility
+lightning-cloud == 0.5.69  # Must be pinned to ensure compatibility
 packaging
 typing-extensions >=4.4.0, <4.10.0
 deepdiff >=5.7.0, <6.6.0
diff --git a/src/lightning/app/core/app.py b/src/lightning/app/core/app.py
index b9ff54f9a8852..c29a43ba9db0a 100644
--- a/src/lightning/app/core/app.py
+++ b/src/lightning/app/core/app.py
@@ -30,6 +30,7 @@
 from lightning.app.api.request_types import _APIRequest, _CommandRequest, _DeltaRequest
 from lightning.app.core.constants import (
     BATCH_DELTA_COUNT,
+    CHECK_ERROR_QUEUE_INTERVAL,
     DEBUG_ENABLED,
     FLOW_DURATION_SAMPLES,
     FLOW_DURATION_THRESHOLD,
@@ -165,6 +166,7 @@ def __init__(
 
         self._last_run_time: float = 0.0
         self._run_times: list = []
+        self._last_check_error_queue: float = 0.0
 
         # Path attributes can't get properly attached during the initialization, because the full name
         # is only available after all Flows and Works have been instantiated.
@@ -318,10 +320,12 @@ def batch_get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] =
             return []
 
     def check_error_queue(self) -> None:
-        exception: Exception = self.get_state_changed_from_queue(self.error_queue)  # type: ignore[assignment,arg-type]
-        if isinstance(exception, Exception):
-            self.exception = exception
-            self.stage = AppStage.FAILED
+        if (time() - self._last_check_error_queue) > CHECK_ERROR_QUEUE_INTERVAL:
+            exception: Exception = self.get_state_changed_from_queue(self.error_queue)  # type: ignore[assignment,arg-type]
+            if isinstance(exception, Exception):
+                self.exception = exception
+                self.stage = AppStage.FAILED
+            self._last_check_error_queue = time()
 
     @property
     def flows(self) -> List[Union[LightningWork, "LightningFlow"]]:
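
The change above turns check_error_queue into a time-throttled check: the error queue is only polled once per CHECK_ERROR_QUEUE_INTERVAL seconds instead of on every loop iteration. Below is a minimal, self-contained sketch of the same pattern; the class and names are illustrative, not the Lightning API.

    import time

    CHECK_INTERVAL = 30.0  # mirrors CHECK_ERROR_QUEUE_INTERVAL (seconds)

    class ThrottledErrorCheck:
        """Minimal sketch of the throttling pattern used by the patched check_error_queue."""

        def __init__(self, poll_queue):
            self._poll_queue = poll_queue  # callable returning an Exception or None
            self._last_check = 0.0
            self.failed = False

        def check(self):
            # Skip the (potentially expensive) queue read until the interval has elapsed.
            if (time.time() - self._last_check) > CHECK_INTERVAL:
                error = self._poll_queue()
                if isinstance(error, Exception):
                    self.failed = True
                self._last_check = time.time()
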
diff --git a/src/lightning/app/core/constants.py b/src/lightning/app/core/constants.py
index 566fc87bc9438..4a06f66e5a474 100644
--- a/src/lightning/app/core/constants.py
+++ b/src/lightning/app/core/constants.py
@@ -70,6 +70,7 @@ def get_lightning_cloud_url() -> str:
 LIGHTNING_COMPONENT_PUBLIC_REGISTRY = "https://lightning.ai/v1/components"
 LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps"
 LIGHTNING_MODELS_PUBLIC_REGISTRY = "https://lightning.ai/v1/models"
+ENABLE_ORCHESTRATOR = bool(int(os.getenv("ENABLE_ORCHESTRATOR", "1")))
 
 LIGHTNING_CLOUDSPACE_HOST = os.getenv("LIGHTNING_CLOUDSPACE_HOST")
 LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT = int(os.getenv("LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT", "0"))
@@ -99,6 +100,7 @@ def get_lightning_cloud_url() -> str:
 SYS_CUSTOMIZATIONS_SYNC_PATH = ".sys-customizations-sync"
 
 BATCH_DELTA_COUNT = int(os.getenv("BATCH_DELTA_COUNT", "128"))
+CHECK_ERROR_QUEUE_INTERVAL = float(os.getenv("CHECK_ERROR_QUEUE_INTERVAL", "30"))
 
 
 def enable_multiple_works_in_default_container() -> bool:
diff --git a/src/lightning/app/runners/cloud.py b/src/lightning/app/runners/cloud.py
index c488014450b9b..80fb03499e678 100644
--- a/src/lightning/app/runners/cloud.py
+++ b/src/lightning/app/runners/cloud.py
@@ -34,7 +34,7 @@
     CloudspaceIdRunsBody,
     Externalv1LightningappInstance,
     Gridv1ImageSpec,
-    IdGetBody1,
+    IdGetBody,
     ProjectIdCloudspacesBody,
     V1BuildSpec,
     V1CloudSpace,
@@ -1027,7 +1027,7 @@ def _api_create_run_instance(
             project_id=project_id,
             cloudspace_id=cloudspace_id,
             id=run_id,
-            body=IdGetBody1(
+            body=IdGetBody(
                 cluster_id=cluster_id,
                 name=run_name,
                 desired_state=desired_state,
diff --git a/src/lightning/app/runners/multiprocess.py b/src/lightning/app/runners/multiprocess.py
index c3217197a6a33..94d627e95fc7b 100644
--- a/src/lightning/app/runners/multiprocess.py
+++ b/src/lightning/app/runners/multiprocess.py
@@ -81,16 +81,17 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any):
 
             _set_flow_context()
 
-            storage_orchestrator = StorageOrchestrator(
-                self.app,
-                self.app.request_queues,
-                self.app.response_queues,
-                self.app.copy_request_queues,
-                self.app.copy_response_queues,
-            )
-            self.threads.append(storage_orchestrator)
-            storage_orchestrator.setDaemon(True)
-            storage_orchestrator.start()
+            if constants.ENABLE_ORCHESTRATOR:
+                storage_orchestrator = StorageOrchestrator(
+                    self.app,
+                    self.app.request_queues,
+                    self.app.response_queues,
+                    self.app.copy_request_queues,
+                    self.app.copy_response_queues,
+                )
+                self.threads.append(storage_orchestrator)
+                storage_orchestrator.setDaemon(True)
+                storage_orchestrator.start()
 
             if self.start_server:
                 self.app.should_publish_changes_to_api = True
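
With the gate above, the storage orchestrator thread becomes opt-out: ENABLE_ORCHESTRATOR defaults to "1" and is parsed with bool(int(...)) in constants.py, so it must be set to "0" (not "false") to disable it. A hedged usage sketch for a local multiprocess run; the script name is a placeholder:

    import os
    import subprocess

    # Run an app with the storage orchestrator thread disabled.
    env = {**os.environ, "ENABLE_ORCHESTRATOR": "0"}
    subprocess.run(["python", "my_app.py"], env=env, check=True)
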
diff --git a/src/lightning/app/utilities/network.py b/src/lightning/app/utilities/network.py
index 04afdb0b4f92c..a7cc00fde52b7 100644
--- a/src/lightning/app/utilities/network.py
+++ b/src/lightning/app/utilities/network.py
@@ -96,10 +96,14 @@ def create_retry_strategy():
         # are going to be alive for a very long time (~ 4 days) but retries every 120 seconds
         total=_CONNECTION_RETRY_TOTAL,
         backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
+        # Any 4xx and 5xx statuses except
+        # 400 Bad Request
+        # 401 Unauthorized
+        # 403 Forbidden
+        # 404 Not Found
         status_forcelist={
-            408,  # Request Timeout
-            429,  # Too Many Requests
-            *range(500, 600),  # Any 5xx Server Error status
+            402,
+            *range(405, 600),
         },
         allowed_methods={
             "POST",  # Default methods are idempotent, add POST here
diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 987d547af7972..bbd59f470220e 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
+## [2.2.5] - 2024-05-23
+
+### Fixed
+
+- Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) ([#19886](https://github.com/Lightning-AI/lightning/pull/19886))
+
+
 ## [2.2.2] - 2024-04-11
 
 ### Fixed
diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py
index 9d4d5ff038a18..b78baeb42c2fe 100644
--- a/src/lightning/fabric/plugins/precision/bitsandbytes.py
+++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -234,9 +234,9 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc
             """Inplace quantize."""
             if weight is None:
                 weight = self.weight.data
-                if weight.data.type == torch.int8:
-                    # already quantized
-                    return
+            if weight.data.dtype == torch.int8:
+                # already quantized
+                return
             assert isinstance(self.weight, bnb.nn.Int8Params)
             self.weight = self.quantize(self.weight, weight, device)
 
@@ -318,9 +318,9 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc
             """Inplace quantize."""
             if weight is None:
                 weight = self.weight.data
-                if weight.data.type == torch.uint8:
-                    # already quantized
-                    return
+            if weight.data.dtype == torch.uint8:
+                # already quantized
+                return
             assert isinstance(self.weight, bnb.nn.Params4bit)
             self.weight = self.quantize(self.weight, weight, device)
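
The fix above addresses two issues: Tensor.type is a method (it returns a type string when called), so comparing it to a dtype was always False, and the early-return check sat inside the weight-is-None branch, so it was skipped whenever a weight was passed explicitly. A quick standalone illustration of the attribute mix-up:

    import torch

    t = torch.zeros(4, dtype=torch.int8)
    print(t.type)                 # a bound method, not a dtype
    print(t.type())               # 'torch.CharTensor' -- a string when called
    print(t.type == torch.int8)   # False, unconditionally (method vs. dtype)
    print(t.dtype == torch.int8)  # True -- the comparison the fix switches to
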
 
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 691c34c563d7f..336938073e0f1 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
+## [2.2.5] - 2024-05-23
+
+### Fixed
+
+- Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) ([#19886](https://github.com/Lightning-AI/lightning/pull/19886))
+
+
 ## [2.2.3] - 2024-04-23
 
 ### Fixed
diff --git a/src/version.info b/src/version.info
index 530cdd91a205a..21bb5e156fbe2 100644
--- a/src/version.info
+++ b/src/version.info
@@ -1 +1 @@
-2.2.4
+2.2.5
diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py
index 452e66793cf18..a11481986f25f 100644
--- a/tests/tests_app/core/test_lightning_app.py
+++ b/tests/tests_app/core/test_lightning_app.py
@@ -1187,3 +1187,27 @@ def run(self):
 def test_lightning_work_stopped():
     app = LightningApp(SimpleWork2())
     MultiProcessRuntime(app, start_server=False).dispatch()
+
+
+class FailedWork(LightningWork):
+    def run(self):
+        raise Exception
+
+
+class CheckErrorQueueLightningApp(LightningApp):
+    def check_error_queue(self):
+        super().check_error_queue()
+
+
+def test_error_queue_check(monkeypatch):
+    import sys
+
+    from lightning.app.core import app as app_module
+
+    sys_mock = mock.MagicMock()
+    monkeypatch.setattr(app_module, "CHECK_ERROR_QUEUE_INTERVAL", 0)
+    monkeypatch.setattr(sys, "exit", sys_mock)
+    app = LightningApp(FailedWork())
+    MultiProcessRuntime(app, start_server=False).dispatch()
+    assert app.stage == AppStage.FAILED
+    assert app._last_check_error_queue != 0.0
diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py
index 74b74c99a8049..5f397284ebeaa 100644
--- a/tests/tests_app/runners/test_cloud.py
+++ b/tests/tests_app/runners/test_cloud.py
@@ -24,7 +24,7 @@
     Externalv1Cluster,
     Externalv1LightningappInstance,
     Gridv1ImageSpec,
-    IdGetBody1,
+    IdGetBody,
     ProjectIdProjectclustersbindingsBody,
     V1BuildSpec,
     V1CloudSpace,
@@ -508,7 +508,7 @@ def test_basic_auth_enabled(self, tmpdir, monkeypatch):
             project_id="test-project-id",
             cloudspace_id=mock.ANY,
             id=mock.ANY,
-            body=IdGetBody1(
+            body=IdGetBody(
                 desired_state=mock.ANY,
                 name=mock.ANY,
                 env=mock.ANY,
@@ -712,7 +712,7 @@ def test_call_with_queue_server_type_specified(self, tmpdir, lightningapps, monk
         cloud_runtime.dispatch()
 
         # calling with no env variable set
-        body = IdGetBody1(
+        body = IdGetBody(
             desired_state=V1LightningappInstanceState.STOPPED,
             env=[],
             name=mock.ANY,
@@ -727,7 +727,7 @@ def test_call_with_queue_server_type_specified(self, tmpdir, lightningapps, monk
         monkeypatch.setitem(os.environ, "LIGHTNING_CLOUD_QUEUE_TYPE", "http")
         cloud_runtime.backend.client.reset_mock()
         cloud_runtime.dispatch()
-        body = IdGetBody1(
+        body = IdGetBody(
             desired_state=V1LightningappInstanceState.STOPPED,
             env=mock.ANY,
             name=mock.ANY,
@@ -998,7 +998,7 @@ def test_call_with_work_app_and_app_comment_command_execution_set(self, lightnin
                 project_id="test-project-id",
                 cloudspace_id=mock.ANY,
                 id=mock.ANY,
-                body=IdGetBody1(
+                body=IdGetBody(
                     desired_state=V1LightningappInstanceState.STOPPED,
                     name=mock.ANY,
                     env=[V1EnvVar(name="ENABLE_APP_COMMENT_COMMAND_EXECUTION", value="1")],
diff --git a/tests/tests_app/utilities/test_network.py b/tests/tests_app/utilities/test_network.py
index 38c8961919db6..3a14c0301ef1e 100644
--- a/tests/tests_app/utilities/test_network.py
+++ b/tests/tests_app/utilities/test_network.py
@@ -49,7 +49,8 @@ def test_find_free_network_port_cloudspace(_, patch_constants):
 def test_http_client_retry_post(getconn_mock):
     getconn_mock.return_value.getresponse.side_effect = [
         mock.Mock(status=500, msg=HTTPMessage()),
-        mock.Mock(status=429, msg=HTTPMessage()),
+        mock.Mock(status=599, msg=HTTPMessage()),
+        mock.Mock(status=405, msg=HTTPMessage()),
         mock.Mock(status=200, msg=HTTPMessage()),
     ]
 
@@ -61,6 +62,7 @@ def test_http_client_retry_post(getconn_mock):
         mock.call("POST", "/test", body=None, headers=mock.ANY),
         mock.call("POST", "/test", body=None, headers=mock.ANY),
         mock.call("POST", "/test", body=None, headers=mock.ANY),
+        mock.call("POST", "/test", body=None, headers=mock.ANY),
     ]
 
 
@@ -68,7 +70,8 @@ def test_http_client_retry_post(getconn_mock):
 def test_http_client_retry_get(getconn_mock):
     getconn_mock.return_value.getresponse.side_effect = [
         mock.Mock(status=500, msg=HTTPMessage()),
-        mock.Mock(status=429, msg=HTTPMessage()),
+        mock.Mock(status=599, msg=HTTPMessage()),
+        mock.Mock(status=405, msg=HTTPMessage()),
         mock.Mock(status=200, msg=HTTPMessage()),
     ]
 
@@ -80,4 +83,5 @@ def test_http_client_retry_get(getconn_mock):
         mock.call("GET", "/test", body=None, headers=mock.ANY),
         mock.call("GET", "/test", body=None, headers=mock.ANY),
         mock.call("GET", "/test", body=None, headers=mock.ANY),
+        mock.call("GET", "/test", body=None, headers=mock.ANY),
     ]
diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
index ec02796b4b51c..a88e7c2be7b3a 100644
--- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
+++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
@@ -230,3 +230,37 @@ def __init__(self):
     assert not keys.missing_keys
     assert model.l.weight.device.type == "cuda"
     assert model.l.weight.dtype == expected
+
+
+@RunIf(min_cuda_gpus=1, min_torch="2.1")
+@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
+def test_load_quantized_checkpoint(tmp_path):
+    """Test that a checkpoint saved from a quantized model can be loaded back into a quantized model."""
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(16, 16, bias=False)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    fabric = Fabric(accelerator="cuda", devices=1, plugins=BitsandbytesPrecision("nf4-dq"))
+    model = Model()
+    model = fabric.setup(model)
+    model(torch.randn(2, 16, device=fabric.device))
+    state_dict = model.state_dict()
+    # The checkpoint contains quantized weights
+    assert state_dict["linear.weight"].dtype == torch.uint8
+    assert state_dict["linear.weight"].shape == (128, 1)
+    torch.save(state_dict, tmp_path / "checkpoint.pt")
+
+    fabric = Fabric(accelerator="cuda", devices=1, plugins=BitsandbytesPrecision("nf4-dq"))
+    model = Model()
+    model = fabric.setup(model)
+    state_dict = torch.load(tmp_path / "checkpoint.pt")
+    model.load_state_dict(state_dict)
+    assert model.linear.weight.dtype == torch.uint8
+    assert model.linear.weight.shape == (128, 1)
+    # Shapes match during the forward pass (the weight is dequantized on the fly)
+    model(torch.randn(2, 16, device=fabric.device))
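
For the shapes asserted in the new test: a 16x16 linear layer holds 256 weights, and NF4 packs two 4-bit codes per uint8 byte, so bitsandbytes stores the quantized data as 128 bytes flattened to shape (128, 1). A tiny sanity check of that arithmetic (the packing factor of two codes per byte is the NF4 assumption here):

    out_features, in_features = 16, 16
    n_weights = out_features * in_features  # 256 float weights in the layer
    packed_bytes = n_weights // 2           # two 4-bit NF4 codes per uint8 byte
    assert (packed_bytes, 1) == (128, 1)    # the flattened shape asserted for linear.weight
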