Skip to content

Commit

Permalink
Add more "stats" to the stats command
Browse files Browse the repository at this point in the history
- Add breakdown by model and directory, by total size and number of files
- Add paths to display table
- Adjust color to make it easier to read
- Move the stats command to its own file
- Move helper functions to their own file
- simplify the app.command call for inspect and stats
- add find_model_by_name helper function
  • Loading branch information
regiellis committed Aug 16, 2024
1 parent 0174bfb commit b1c58be
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 85 deletions.
101 changes: 16 additions & 85 deletions src/civitai_model_manager/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@
from rich.traceback import install
install()

from .components.helpers import feedback_message
from .components.helpers import feedback_message, get_model_folder
from .components.utils import convert_kb, clean_text, safe_get
from .components.stats import count_models, get_model_sizes, inspect_models_cli

from ollama import Client as OllamaClient
from openai import OpenAI as OpenAIClient
Expand Down Expand Up @@ -136,8 +137,8 @@ def load_environment_variables(console: Console = Console()) -> None:
"Other": "other"
}

MODEL_TYPES = ["SDXL 1.0", "SDXL 0.9", "SD 1.5", "SD 1.4", "SD 2.0", "SD 2.0 768", "SD 2.1", "SD 2.1 768", "Other"]
# MODEL_TYPES = ["Checkpoint", "TextualInversion", "Hypernetwork", "AestheticGradient", "LORA", "LoCon", "Controlnet", "Poses", "Upscaler", "MotionModule", "VAE", "Poses", "Wildcards", "Workflows", "Other"]
FILE_TYPES = (".safetensors", ".pt", ".pth", ".ckpt")
MODEL_TYPES: Final = ["SDXL 1.0", "SDXL 0.9", "SD 1.5", "SD 1.4", "SD 2.0", "SD 2.0 768", "SD 2.1", "SD 2.1 768", "Other"]

OLLAMA_OPTIONS: Final = {
"model": os.getenv("OLLAMA_MODEL", ""),
Expand All @@ -146,17 +147,17 @@ def load_environment_variables(console: Console = Console()) -> None:
"top_p": os.getenv("TOP_P", 0.3),
"html_output": os.getenv("HTML_OUT", False),
"system_template": (
"You are an expert in giving detailed explanations of description you are provided. Do not present it"
"like an update log. Make sure to explains the full description in a way that is easy to understand. "
"The description is not provided by the user but comes from the CivitAI API, so don't make recommendations on how to "
"improve the description, just be detailed, clear and thorough about the provided content. "
"Include recommended settings and tip if it appears in the description:\n"
" - Tips on Usage\n"
"You are an expert in giving detailed explanations of description you are provided. Do not present it "
"like an update log. Make sure to explains the full description in a clear and concise manner. "
"The description is provided by the CivitAI API and not written by the user, so don't make recommendations "
"on how to improve the description, just be detailed, clear and thorough about the provided content. "
"Include information on recommended settings and tip if they appear in the description:\n "
"- Tips on Usage\n"
"- Sampling method\n"
"- Schedule type\n"
"- Sampling steps\n"
"- CFG Scale\n"
"DO NOT OFFER ADVICE ON HOW TO IMPROVE THE DESCRIPTION!!!\n"
"DO NOT OFFER ADVICE ON HOW TO IMPROVE THE DESCRIPTION!!"
"Return the description in Markdown format.\n\n"
"You will find the description below: \n\n"
)
Expand All @@ -183,7 +184,7 @@ def load_environment_variables(console: Console = Console()) -> None:
console = Console(soft_wrap=True)
h2t = html2text.HTML2Text()
app = typer.Typer()
FILE_TYPES=(".safetensors", ".pt", ".pth", ".ckpt")


# TODO: More robust logic checks...this is just a basic check
def sanity_check() -> None:
Expand Down Expand Up @@ -249,37 +250,6 @@ def list_models(model_dir: str) -> List[Tuple[str, str, str]]:
return models


def count_models(model_dir: str) -> Dict[str, int]:
    """Count model files per top-level entry of ``model_dir``.

    Walks ``model_dir`` recursively and tallies every file whose name ends
    with one of ``FILE_TYPES``. Each file is attributed to the first path
    component of its directory relative to ``model_dir``; files that sit
    directly in ``model_dir`` are therefore counted under ``"."``.

    Args:
        model_dir: Root directory to scan.

    Returns:
        Mapping of top-level directory name to the number of model files
        found beneath it. Directories without model files are omitted.
    """
    tallies = {}
    for current_dir, _, filenames in os.walk(model_dir):
        # First component of the path relative to the scan root.
        bucket = os.path.relpath(current_dir, model_dir).split(os.sep)[0]
        for filename in filenames:
            if not filename.endswith(FILE_TYPES):
                continue
            tallies[bucket] = tallies.get(bucket, 0) + 1
    return tallies


def get_model_sizes(model_dir: str) -> Dict[str, str]:
    """Collect a human-readable size string for every model file.

    Walks ``model_dir`` recursively and, for each file whose name ends with
    one of ``FILE_TYPES``, records its size on disk formatted as
    ``"<x> MB"`` or, from 1 GB upwards, ``"<x> MB (<y> GB)"``.

    Args:
        model_dir: Root directory to scan.

    Returns:
        Mapping of file name (without directory) to its formatted size.
        Because the key is the bare file name, a later file with the same
        name in another directory replaces the earlier entry.
    """
    bytes_per_mb = 1024 * 1024
    bytes_per_gb = 1024 * 1024 * 1024

    sizes = {}
    for current_dir, _, filenames in os.walk(model_dir):
        for filename in filenames:
            if not filename.endswith(FILE_TYPES):
                continue

            byte_count = os.path.getsize(os.path.join(current_dir, filename))
            in_mb = byte_count / bytes_per_mb
            in_gb = byte_count / bytes_per_gb

            if in_gb >= 1:
                label = f"{in_mb:.2f} MB ({in_gb:.2f} GB)"
            else:
                label = f"{in_mb:.2f} MB"

            sizes[os.path.basename(filename)] = label
    return sizes

# Search for models by query, tag, or types, which are optional via the api
def search_models(query: str = "", **kwargs) -> List[Dict[str, Any]]:

Expand Down Expand Up @@ -443,7 +413,7 @@ def download_model(model_id: int, model_details: Dict[str, Any], select: bool =

# model_name = f"{model_name}_{selected_version['name'].replace('.', '')}"
download_url = f"{CIVITAI_DOWNLOAD}/{selected_version['id']}?token={CIVITAI_TOKEN}"
model_folder = get_model_folder(model_type)
model_folder = get_model_folder(MODELS_DIR, model_type, TYPES)
model_path = os.path.join(model_folder, selected_version.get('base_model', ''), f"{selected_version.get('file')}")


Expand Down Expand Up @@ -478,14 +448,6 @@ def select_version(model_name: str, versions: List[Dict[str, Any]]) -> Optional[
return None


def get_model_folder(model_type: str) -> str:
    """Return the folder under ``MODELS_DIR`` used for ``model_type``.

    Mapped types resolve through ``TYPES``. For an unmapped type the user is
    told about it and prompted for a folder name (default ``"unknown"``).

    Args:
        model_type: Model type name to look up in ``TYPES``.

    Returns:
        ``MODELS_DIR`` joined with the mapped or user-supplied folder name.
    """
    if model_type in TYPES:
        return os.path.join(MODELS_DIR, TYPES[model_type])

    console.print(f"Model type '{model_type}' is not mapped to any folder. Please select a folder to download the model.")
    chosen_folder = typer.prompt("Enter the folder name to download the model:", default="unknown")
    return os.path.join(MODELS_DIR, chosen_folder)


def download_file(url: str, path: str, desc: str) -> Optional[str]:
"""Download a file from a given URL and save it to the specified path."""
try:
Expand Down Expand Up @@ -608,7 +570,7 @@ def list_models_cli():
feedback_message("Invalid selection. Please enter a valid number.", "error")
return

model_folder = get_model_folder(model_type)
model_folder = get_model_folder(MODELS_DIR, model_type, TYPES)
models_in_folder = list_models(model_folder)

if not models_in_folder:
Expand All @@ -626,38 +588,7 @@ def list_models_cli():


@app.command("stats", help="Stats on the parent models directory.")
def inspect_models_cli():
    """Stats on the parent models directory.

    Prints a per-directory model count table followed by a table of the ten
    largest model files, then a warning that counts are per location rather
    than per model type. Prints a notice and returns early when no model
    files are found under ``MODELS_DIR``.
    """
    counts_by_dir = count_models(MODELS_DIR)
    if not counts_by_dir:
        console.print("No models found.", style="yellow")
        return

    grand_total = sum(counts_by_dir.values())

    counts_table = Table(title_justify="left")
    counts_table.add_column("Model Type", style="cyan")
    counts_table.add_column(f"Model Per // Total Model Count: {grand_total}", style="yellow")
    for dir_name, dir_count in counts_by_dir.items():
        counts_table.add_row(dir_name, str(dir_count))
    console.print(counts_table)

    def leading_mb(entry):
        # Size strings start with the MB figure, e.g. "12.50 MB (…)".
        return float(entry[1].split()[0])

    # Get top 10 largest models
    size_by_name = get_model_sizes(MODELS_DIR)

    sizes_table = Table(title_justify="left")
    sizes_table.add_column("Top 10 Largest Models // Model Name", style="cyan")
    sizes_table.add_column("Size on Disk", style="yellow")

    largest_first = sorted(size_by_name.items(), key=leading_mb, reverse=True)
    for file_name, size_label in largest_first[:10]:
        sizes_table.add_row(file_name, size_label)

    console.print(sizes_table)
    feedback_message("""Warning: This is an overall count of files in a location and not based on model types.
i.e. SDXL models and 1.5/2.1 will not be seperated in the count""", "warning")
def stats_command():
    """Run the stats report against the configured ``MODELS_DIR``."""
    report = inspect_models_cli(MODELS_DIR)
    return report


@app.command("details", help="Get detailed information about a specific model by ID.")
Expand Down Expand Up @@ -794,7 +725,7 @@ def remove_models_cli():
console.print(table)
return

model_folder = get_model_folder(model_type)
model_folder = get_model_folder(MODELS_DIR, model_type, TYPES)
models_in_folder = list_models(model_folder)

if not models_in_folder:
Expand Down
10 changes: 10 additions & 0 deletions src/civitai_model_manager/components/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
This module contains helper functions for the Civitai Model Manager.
"""
import os
import typer

from typing import Any, Dict, List, Optional, Tuple, Final
from rich.console import Console
from rich import print
Expand Down Expand Up @@ -45,3 +48,10 @@ def feedback_message(message: str, type: str = "info") -> None:
console.print(feedback_message_table)
return None


def get_model_folder(models_dir: str, model_type: str, ref_types: dict) -> str:
    """Return the folder under ``models_dir`` used for ``model_type``.

    Mapped types resolve through ``ref_types``. For an unmapped type the
    user is told about it and prompted for a folder name (default
    ``"unknown"``).

    Args:
        models_dir: Root directory that holds the per-type model folders.
        model_type: Model type name to look up in ``ref_types``.
        ref_types: Mapping of model type name to folder name.

    Returns:
        ``models_dir`` joined with the mapped or user-supplied folder name.
    """
    if model_type in ref_types:
        return os.path.join(models_dir, ref_types[model_type])

    console.print(f"Model type '{model_type}' is not mapped to any folder. Please select a folder to download the model.")
    chosen_folder = typer.prompt("Enter the folder name to download the model:", default="unknown")
    return os.path.join(models_dir, chosen_folder)
125 changes: 125 additions & 0 deletions src/civitai_model_manager/components/stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-

"""
==========================================================
Civitai CLI Manager - Stats
==========================================================
This module contains functions for generating statistics for the Civitai Model Manager.
"""

import os
import json
from typing import Any, Dict, List, Optional, Tuple, Final
from rich.console import Console
from rich.table import Table

from .helpers import feedback_message

# File name suffixes treated as model files when scanning a directory tree.
FILE_TYPES = (".safetensors", ".pt", ".pth", ".ckpt")

# Console the stats tables are printed to (soft wrapping enabled).
stats_console = Console(soft_wrap=True)

def count_models(model_dir: str) -> Dict[str, int]:
    """Count model files per top-level entry of ``model_dir``.

    Walks ``model_dir`` recursively and tallies every file whose name ends
    with one of ``FILE_TYPES``. Each file is attributed to the first path
    component of its directory relative to ``model_dir``; files that sit
    directly in ``model_dir`` are therefore counted under ``"."``.

    Args:
        model_dir: Root directory to scan.

    Returns:
        Mapping of top-level directory name to the number of model files
        found beneath it. Directories without model files are omitted.
    """
    tallies = {}
    for current_dir, _, filenames in os.walk(model_dir):
        # First component of the path relative to the scan root.
        relative_dir = os.path.relpath(current_dir, model_dir)
        bucket = relative_dir.split(os.sep)[0]
        for filename in filenames:
            if filename.endswith(FILE_TYPES):
                tallies[bucket] = tallies.get(bucket, 0) + 1
    return tallies

def get_model_sizes(model_dir: str) -> Dict[str, str]:
    """Collect a human-readable size string for every model file.

    Walks ``model_dir`` recursively and, for each file whose name ends with
    one of ``FILE_TYPES``, records its size on disk formatted as
    ``"<x> MB"`` or, from 1 GB upwards, ``"<x> MB (<y> GB)"``.

    Args:
        model_dir: Root directory to scan.

    Returns:
        Mapping of file name (without directory) to its formatted size.
        Because the key is the bare file name, a later file with the same
        name in another directory replaces the earlier entry.
    """
    bytes_per_mb = 1024 * 1024
    bytes_per_gb = 1024 * 1024 * 1024

    sizes = {}
    for current_dir, _, filenames in os.walk(model_dir):
        model_files = (name for name in filenames if name.endswith(FILE_TYPES))
        for filename in model_files:
            byte_count = os.path.getsize(os.path.join(current_dir, filename))
            in_mb = byte_count / bytes_per_mb
            in_gb = byte_count / bytes_per_gb

            label = f"{in_mb:.2f} MB"
            if in_gb >= 1:
                # Large files also show the GB figure in parentheses.
                label = f"{in_mb:.2f} MB ({in_gb:.2f} GB)"

            sizes[os.path.basename(filename)] = label
    return sizes


def find_model_by_name(model_dir: str, model_name: str) -> Optional[str]:
    """Locate a file called ``model_name`` somewhere under ``model_dir``.

    Args:
        model_dir: Root directory to search recursively.
        model_name: Exact file name to look for (no directory part).

    Returns:
        Path of the first match in ``os.walk`` order, or ``None`` when no
        file with that name exists under ``model_dir``.
    """
    for current_dir, _, filenames in os.walk(model_dir):
        if model_name in filenames:
            return os.path.join(current_dir, model_name)
    return None


def inspect_models_cli(MODELS_DIR: str) -> None:
    """Stats on the parent models directory.

    Prints two tables to ``stats_console``:

    1. One row per top-level directory returned by ``count_models``, with
       its model count, a per-subdirectory file breakdown, its path and its
       share of the total model count.
    2. The ten largest model files with their size on disk and path.

    When no model files are found a warning is shown and nothing else is
    printed.

    Args:
        MODELS_DIR: Root directory that is scanned for model files.
    """
    model_counts = count_models(MODELS_DIR)

    if not model_counts:
        feedback_message("No models found.", "warning")
        # Bare return: ``table`` is not bound yet at this point, so returning
        # it raised UnboundLocalError instead of exiting cleanly.
        return

    total_count = sum(model_counts.values())

    table = Table(title_justify="left")
    table.add_column("Model Type", style="cyan")
    table.add_column("Model Per Directory ", style="bright_yellow")
    table.add_column("Model Type Paths", style="bright_yellow")
    table.add_column("Model Path", style="bright_yellow")
    table.add_column(f"Model Breakdown // Total Model Count: {total_count}", style="bright_yellow")

    for model_type, count in model_counts.items():
        base_path = os.path.join(MODELS_DIR, model_type)
        path_types = [d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))]

        if not path_types:
            path_types_breakdown = "[white]No subdirectories[/white]"
        else:
            # Counts every regular file directly inside each subdirectory
            # (not filtered by FILE_TYPES and not recursive), so these totals
            # can differ from ``count``.
            subdir_counts = {}
            for subdir in path_types:
                subdir_path = os.path.join(base_path, subdir)
                subdir_counts[subdir] = sum(
                    1 for f in os.listdir(subdir_path) if os.path.isfile(os.path.join(subdir_path, f))
                )
            total_subdir_files = sum(subdir_counts.values())

            breakdown_parts = []
            for subdir, subdir_count in subdir_counts.items():
                # Guard against division by zero when every subdirectory is empty.
                percentage = subdir_count / total_subdir_files if total_subdir_files > 0 else 0
                breakdown_parts.append(f"[white]{subdir}:[/white] [bold]{subdir_count} ({percentage:.2%})[/bold]")

            path_types_breakdown = f"{', '.join(breakdown_parts)} (Total: {total_subdir_files})"

        table.add_row(
            model_type,
            f"{count}",
            f"[bright_yellow]{path_types_breakdown}[/bright_yellow]",
            os.path.join(MODELS_DIR, model_type),
            f"{count/total_count:.2%}"
        )

    stats_console.print(table)

    # Get top 10 largest models
    model_sizes = get_model_sizes(MODELS_DIR)

    table = Table(title_justify="left")
    table.add_column("Top 10 Largest Models // Model Name", style="cyan")
    table.add_column("Size on Disk", style="bright_yellow")
    table.add_column("Model Path", style="bright_yellow")

    # Size strings start with the MB figure ("<x> MB ..."), which is used as
    # the numeric sort key.
    for model_name, size in sorted(model_sizes.items(), key=lambda x: float(x[1].split()[0]), reverse=True)[:10]:
        table.add_row(model_name, size, find_model_by_name(MODELS_DIR, model_name))

    stats_console.print(table)

0 comments on commit b1c58be

Please sign in to comment.