Skip to content

Commit

Permalink
Fix download issue with variant models
Browse files Browse the repository at this point in the history
Add more file types for download
Fix #1: Add components folder and move helpers.py to components folder
Fix #2: Add utils.py to components folder
  • Loading branch information
regiellis committed Aug 11, 2024
1 parent f17fd45 commit 3e6a406
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 48 deletions.
32 changes: 25 additions & 7 deletions src/civitai_model_manager/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@
import os
import sys
import json
import re

from typing import Any, Dict, List, Optional, Tuple, Final
import requests
import typer
import html2text
import inquirer

from platform import system
from pathlib import Path
from dotenv import load_dotenv, find_dotenv
Expand All @@ -69,7 +73,8 @@
from rich.traceback import install
install()

from .helpers import feedback_message, convert_kb, clean_text
from .components.helpers import feedback_message
from .components.utils import convert_kb, clean_text

from ollama import Client as OllamaClient
from openai import OpenAI as OpenAIClient
Expand Down Expand Up @@ -114,14 +119,24 @@ def load_environment_variables(console: Console = Console()) -> None:

TYPES: Final = {
"Checkpoint": "checkpoints",
"TextualInversion": "textual_inversions",
"TextualInversion": "embeddings",
"Hypernetwork": "hypernetworks",
"AestheticGradient": "aesthetic_gradients",
"AestheticGradient": "aesthetic_embeddings",
"LORA": "loras",
"Controlnet": "controlnets",
"Poses": "poses"
"LoCon": "models/Lora",
"Controlnet": "controlnet",
"Poses": "poses",
"Upscaler": "esrgan",
"MotionModule": "motion_module",
"VAE": "VAE",
"Wildcards": "wildcards",
"Workflows": "workflows",
"Other": "other"
}

MODEL_TYPES = ["SDXL 1.0", "SDXL 0.9", "SD 1.5", "SD 1.4", "SD 2.0", "SD 2.0 768", "SD 2.1", "SD 2.1 768", "Other"]
# MODEL_TYPES = ["Checkpoint", "TextualInversion", "Hypernetwork", "AestheticGradient", "LORA", "LoCon", "Controlnet", "Poses", "Upscaler", "MotionModule", "VAE", "Poses", "Wildcards", "Workflows", "Other"]

OLLAMA_OPTIONS: Final = {
"model": os.getenv("OLLAMA_MODEL", ""),
"api_base": os.getenv("OLLAMA_API_BASE", ""),
Expand Down Expand Up @@ -371,7 +386,7 @@ def get_model_details(model_id: int) -> Dict[str, Any]:
"nsfw": Text("Yes", style="yellow") if response.get("nsfw", False) else Text("No", style="bright_red"),
"download_url": response.get("downloadUrl", ""),
"metadata": {
"stats": f"{response['stat'].get('downloadCount', '')} downloads, {response['stats'].get('thumbsUpCount', '')} likes, {version_data['stats'].get('thumbsDownCount', '')} dislikes",
"stats": f"{response['stats'].get('downloadCount', '')} downloads, {response['stats'].get('thumbsUpCount', '')} likes, {version_data['stats'].get('thumbsDownCount', '')} dislikes",
"size": convert_kb(response["files"][0].get("sizeKB", "")),
"format": response["files"][0].get("metadata").get("format", ".safetensors"),
"file": response["files"][0].get("name", ""),
Expand Down Expand Up @@ -406,7 +421,8 @@ def download_model(model_id: int, model_details: Dict[str, Any], select: bool =
"name": model_details.get("name", ""),
"base_model": model_details.get("base_model", ""),
"download_url": model_details.get("download_url", ""),
"images": model_details["images"][0].get("url", "")
"images": model_details["images"][0].get("url", ""),
"file": model_meta.get("file", "")
}
else:
if model_details.get("parent_id"):
Expand All @@ -423,10 +439,12 @@ def download_model(model_id: int, model_details: Dict[str, Any], select: bool =
download_url = f"{CIVITAI_DOWNLOAD}/{selected_version['id']}?token={CIVITAI_TOKEN}"
model_folder = get_model_folder(model_type)
model_path = os.path.join(model_folder, selected_version.get('base_model', ''), f"{selected_version.get('file')}")


if os.path.exists(model_path):
feedback_message(f"Model {model_name} already exists at {model_path}. Skipping download.", "warning")
return None


os.makedirs(os.path.dirname(model_path), exist_ok=True)
return download_file(download_url, model_path, model_name)
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,20 @@

"""
==========================================================
Civitai Model Manager - Helpers
Civitai CLI Manager - Helpers
==========================================================
This module contains helper functions for the Civitai Model Manager.
"""
import os
import sys
import json
from typing import Any, Dict, List, Optional, Tuple, Final
from rich.console import Console
from rich import print
from rich.table import Table
from rich.traceback import install
install()

console = Console(soft_wrap=True)


def clean_text(text: str) -> str:
"""_summary_
Args:
text (str): _description_
Returns:
str: _description_
"""
return text.replace("\n", " ").replace("\r", " ").replace("\t", " ").strip()


def convert_kb(kb: float) -> str:
"""_summary_
Args:
kb (float): _description_
Raises:
ValueError: _description_
Returns:
str: _description_
"""
if kb <= 0:
raise ValueError("Input must be a positive number.")
units = ["KB", "MB", "GB"]
i = 0
while kb >= 1024 and i < len(units) - 1:
kb /= 1024.0
i += 1

return f"{round(kb, 2)} {units[i]}"


def feedback_message(message: str, type: str = "info") -> None:
"""_summary_
Expand Down
47 changes: 47 additions & 0 deletions src/civitai_model_manager/components/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
==========================================================
Civitai CLI Manager - UTILS
==========================================================
This module contains utility functions for the Civitai Model Manager.
"""

import os
import sys
import json


def clean_text(text: str) -> str:
"""_summary_
Args:
text (str): _description_
Returns:
str: _description_
"""
return text.replace("\n", " ").replace("\r", " ").replace("\t", " ").strip()


def convert_kb(kb: float) -> str:
"""_summary_
Args:
kb (float): _description_
Raises:
ValueError: _description_
Returns:
str: _description_
"""
if kb <= 0:
raise ValueError("Input must be a positive number.")
units = ["KB", "MB", "GB"]
i = 0
while kb >= 1024 and i < len(units) - 1:
kb /= 1024.0
i += 1

return f"{round(kb, 2)} {units[i]}"

0 comments on commit 3e6a406

Please sign in to comment.