From 741376a0859e946298796c26dfddab9729a1db39 Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Thu, 1 Aug 2024 09:09:46 +0200 Subject: [PATCH 1/3] fix: add ollama progress bar when pulling models --- private_gpt/utils/ollama.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/private_gpt/utils/ollama.py b/private_gpt/utils/ollama.py index 41c7ecc46..e13f709b8 100644 --- a/private_gpt/utils/ollama.py +++ b/private_gpt/utils/ollama.py @@ -1,4 +1,6 @@ import logging +from typing import Any, Generator, Mapping, Iterator +from tqdm import tqdm try: from ollama import Client # type: ignore @@ -19,12 +21,39 @@ def check_connection(client: Client) -> bool: return False +def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None: + progress_bars = {} + + def create_progress_bar(total: int) -> tqdm: + return tqdm(total=total, desc=f"Pulling model", unit='B', unit_scale=True) + + for chunk in generator: + digest = chunk.get("digest") + completed_size = chunk.get("completed", 0) + total_size = chunk.get("total") + + if digest and total_size is not None: + if digest not in progress_bars: + progress_bars[digest] = create_progress_bar(total=total_size) + + progress_bar = progress_bars[digest] + progress_bar.update(completed_size - progress_bar.n) + + if completed_size == total_size: + progress_bar.close() + del progress_bars[digest] + + # Close any remaining progress bars at the end + for progress_bar in progress_bars.values(): + progress_bar.close() + + def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None: try: installed_models = [model["name"] for model in client.list().get("models", {})] if model_name not in installed_models: logger.info(f"Pulling model {model_name}. Please wait...") - client.pull(model_name) + process_streaming(client.pull(model_name, stream=True)) logger.info(f"Model {model_name} pulled successfully") except Exception as e: logger.error(f"Failed to pull model {model_name}: {e!s}") From 21c622ee278e3da97fef15cd180b28aa906d9dc2 Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Thu, 1 Aug 2024 11:55:41 +0200 Subject: [PATCH 2/3] feat: add ollama queue --- private_gpt/utils/ollama.py | 39 ++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/private_gpt/utils/ollama.py b/private_gpt/utils/ollama.py index e13f709b8..95070e4a7 100644 --- a/private_gpt/utils/ollama.py +++ b/private_gpt/utils/ollama.py @@ -1,6 +1,7 @@ import logging from typing import Any, Generator, Mapping, Iterator from tqdm import tqdm +from collections import deque try: from ollama import Client # type: ignore @@ -23,9 +24,12 @@ def check_connection(client: Client) -> bool: def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None: progress_bars = {} + queue = deque() - def create_progress_bar(total: int) -> tqdm: - return tqdm(total=total, desc=f"Pulling model", unit='B', unit_scale=True) + def create_progress_bar(dgt: str, total: int) -> tqdm: + return tqdm(total=total, desc=f"Pulling model {dgt[7:17]}...", unit='B', unit_scale=True) + + current_digest = None for chunk in generator: digest = chunk.get("digest") @@ -33,21 +37,34 @@ def create_progress_bar(total: int) -> tqdm: total_size = chunk.get("total") if digest and total_size is not None: - if digest not in progress_bars: - progress_bars[digest] = create_progress_bar(total=total_size) - - progress_bar = progress_bars[digest] - progress_bar.update(completed_size - progress_bar.n) + if digest not in progress_bars and completed_size > 0: + progress_bars[digest] = create_progress_bar(digest, total=total_size) + if current_digest is None: + current_digest = digest + else: + queue.append(digest) - if completed_size == total_size: - progress_bar.close() - del progress_bars[digest] + if digest in progress_bars: + progress_bar = progress_bars[digest] + progress = completed_size - progress_bar.n + if completed_size > 0 and total_size >= progress != progress_bar.n: + if digest == current_digest: + progress_bar.update(progress) + if progress_bar.n >= total_size: + progress_bar.close() + if queue: + current_digest = queue.popleft() + else: + current_digest = None + else: + # Store progress for later update + progress_bars[digest].total = total_size + progress_bars[digest].n = completed_size # Close any remaining progress bars at the end for progress_bar in progress_bars.values(): progress_bar.close() - def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None: try: installed_models = [model["name"] for model in client.list().get("models", {})] From 755f886c299cd6452860de8ad3fee82c11aeac4f Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Thu, 1 Aug 2024 15:50:55 +0200 Subject: [PATCH 3/3] fix: mypy --- private_gpt/utils/ollama.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/private_gpt/utils/ollama.py b/private_gpt/utils/ollama.py index 95070e4a7..9c75a875a 100644 --- a/private_gpt/utils/ollama.py +++ b/private_gpt/utils/ollama.py @@ -1,7 +1,9 @@ import logging -from typing import Any, Generator, Mapping, Iterator -from tqdm import tqdm from collections import deque +from collections.abc import Iterator, Mapping +from typing import Any + +from tqdm import tqdm # type: ignore try: from ollama import Client # type: ignore @@ -24,10 +26,12 @@ def check_connection(client: Client) -> bool: def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None: progress_bars = {} - queue = deque() + queue = deque() # type: ignore - def create_progress_bar(dgt: str, total: int) -> tqdm: - return tqdm(total=total, desc=f"Pulling model {dgt[7:17]}...", unit='B', unit_scale=True) + def create_progress_bar(dgt: str, total: int) -> Any: + return tqdm( + total=total, desc=f"Pulling model {dgt[7:17]}...", unit="B", unit_scale=True + ) current_digest = None @@ -52,10 +56,7 @@ def create_progress_bar(dgt: str, total: int) -> tqdm: progress_bar.update(progress) if progress_bar.n >= total_size: progress_bar.close() - if queue: - current_digest = queue.popleft() - else: - current_digest = None + current_digest = queue.popleft() if queue else None else: # Store progress for later update progress_bars[digest].total = total_size @@ -65,6 +66,7 @@ def create_progress_bar(dgt: str, total: int) -> tqdm: for progress_bar in progress_bars.values(): progress_bar.close() + def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None: try: installed_models = [model["name"] for model in client.list().get("models", {})]