Added support for MPS on Apple silicon #233

Open · wants to merge 3 commits into master
3 changes: 3 additions & 0 deletions .gitignore
@@ -70,3 +70,6 @@ examples/speech.mp3
examples/phoneme_examples/output/*.wav
examples/assorted_checks/benchmarks/output_audio/*
uv.lock

# Mac MPS virtualenv for dual testing
.venv-mps
18 changes: 18 additions & 0 deletions api/src/core/config.py
@@ -1,4 +1,5 @@
from pydantic_settings import BaseSettings
import torch


class Settings(BaseSettings):
@@ -15,6 +16,7 @@ class Settings(BaseSettings):
default_voice: str = "af_heart"
default_voice_code: str | None = None # If set, overrides the first letter of voice name, though api call param still takes precedence
use_gpu: bool = True # Whether to use GPU acceleration if available
device_type: str | None = None # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
allow_local_voice_saving: bool = (
False # Whether to allow saving combined voices locally
)
@@ -51,5 +53,21 @@ class Settings(BaseSettings):
class Config:
env_file = ".env"

def get_device(self) -> str:
"""Get the appropriate device based on settings and availability"""
if not self.use_gpu:
return "cpu"

if self.device_type:
return self.device_type

# Auto-detect device
if torch.backends.mps.is_available():
return "mps"
elif torch.cuda.is_available():
return "cuda"
return "cpu"



settings = Settings()
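
A quick illustration of the precedence get_device() implements; a hypothetical standalone snippet, assuming the import path mirrors the file location above and PYTHONPATH is set as in the start scripts:

# Sketch: use_gpu=False wins, then an explicit device_type, then auto-detect (MPS > CUDA > CPU).
from api.src.core.config import Settings  # import path assumed from the diff above

assert Settings(use_gpu=False).get_device() == "cpu"        # GPU disabled forces CPU
assert Settings(device_type="cuda").get_device() == "cuda"  # explicit override skips detection
print(Settings().get_device())  # auto-detect order: "mps", "cuda", "cpu"
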
22 changes: 16 additions & 6 deletions api/src/inference/kokoro_v1.py
@@ -21,7 +21,7 @@ def __init__(self):
"""Initialize backend with environment-based configuration."""
super().__init__()
# Strictly respect settings.use_gpu
self._device = "cuda" if settings.use_gpu else "cpu"
self._device = settings.get_device()
self._model: Optional[KModel] = None
self._pipelines: Dict[str, KPipeline] = {} # Store pipelines by lang_code

@@ -48,9 +48,14 @@ async def load_model(self, path: str) -> None:

# Load model and let KModel handle device mapping
self._model = KModel(config=config_path, model=model_path).eval()
-        # Move to CUDA if needed
-        if self._device == "cuda":
+        # For MPS, move the whole model to MPS; ops without MPS kernels fall back to CPU
+        # via PYTORCH_ENABLE_MPS_FALLBACK (set in the start script)
+        if self._device == "mps":
+            logger.info("Moving model to MPS device with CPU fallback for unsupported operations")
+            self._model = self._model.to(torch.device("mps"))
+        elif self._device == "cuda":
            self._model = self._model.cuda()
+        else:
+            self._model = self._model.cpu()

except FileNotFoundError as e:
raise e
@@ -273,7 +278,7 @@ async def generate(
continue
if not token.text or not token.text.strip():
continue

start_time = float(token.start_ts) + current_offset
end_time = float(token.end_ts) + current_offset
word_timestamps.append(
@@ -291,8 +296,8 @@
logger.error(
f"Failed to process timestamps for chunk: {e}"
)


yield AudioChunk(result.audio.numpy(), word_timestamps=word_timestamps)
else:
logger.warning("No audio in chunk")
@@ -314,13 +319,18 @@ def _check_memory(self) -> bool:
if self._device == "cuda":
memory_gb = torch.cuda.memory_allocated() / 1e9
return memory_gb > model_config.pytorch_gpu.memory_threshold
# MPS doesn't provide memory management APIs
return False

def _clear_memory(self) -> None:
"""Clear device memory."""
if self._device == "cuda":
torch.cuda.empty_cache()
torch.cuda.synchronize()
elif self._device == "mps":
# Empty cache if available (future-proofing)
if hasattr(torch.mps, 'empty_cache'):
torch.mps.empty_cache()

def unload(self) -> None:
"""Unload model and free resources."""
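
For context, a minimal sketch of the CPU-fallback behavior the MPS path relies on; standalone and illustrative only, and note that PYTORCH_ENABLE_MPS_FALLBACK must be set before torch is imported (which is why the start script exports it up front):

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # must be set before importing torch

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
model = torch.nn.Linear(8, 8).to(device)
out = model(torch.randn(1, 8, device=device))
print(out.device)  # ops without MPS kernels fall back to CPU instead of raising
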
2 changes: 1 addition & 1 deletion api/src/inference/voice_manager.py
@@ -19,7 +19,7 @@ class VoiceManager:
def __init__(self):
"""Initialize voice manager."""
# Strictly respect settings.use_gpu
self._device = "cuda" if settings.use_gpu else "cpu"
self._device = settings.get_device()
self._voices: Dict[str, torch.Tensor] = {}

async def get_voice_path(self, voice_name: str) -> str:
7 changes: 6 additions & 1 deletion api/src/main.py
@@ -85,7 +85,12 @@ async def lifespan(app: FastAPI):
{boundary}
"""
startup_msg += f"\nModel warmed up on {device}: {model}"
startup_msg += f"CUDA: {torch.cuda.is_available()}"
if device == "mps":
startup_msg += "\nUsing Apple Metal Performance Shaders (MPS)"
elif device == "cuda":
startup_msg += f"\nCUDA: {torch.cuda.is_available()}"
else:
startup_msg += "\nRunning on CPU"
startup_msg += f"\n{voicepack_count} voice packs loaded"

# Add web player info if enabled
10 changes: 9 additions & 1 deletion api/src/routers/debug.py
@@ -4,6 +4,7 @@

import psutil
from fastapi import APIRouter
import torch

try:
import GPUtil
@@ -113,7 +114,14 @@ async def get_system_info():

# GPU Info if available
gpu_info = None
-    if GPU_AVAILABLE:
+    if torch.backends.mps.is_available():
+        gpu_info = {
+            "type": "MPS",
+            "available": True,
+            "device": "Apple Silicon",
+            "backend": "Metal"
+        }
+    elif GPU_AVAILABLE:
try:
gpus = GPUtil.getGPUs()
gpu_info = [
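
For reference, a hedged sketch of reading the new MPS branch from a client; the route path and surrounding response shape are assumptions, only the gpu_info fields come from the diff above:

import requests  # hypothetical client; endpoint path assumed

info = requests.get("http://localhost:8880/debug/system").json()
gpu = info.get("gpu_info")
if isinstance(gpu, dict) and gpu.get("type") == "MPS":
    print(f"GPU: {gpu['device']} via {gpu['backend']}")  # "Apple Silicon via Metal"
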
35 changes: 35 additions & 0 deletions start-gpu_mac.sh
@@ -0,0 +1,35 @@
#!/bin/bash

# Get project root directory
PROJECT_ROOT=$(pwd)

# Create mps-specific venv directory
VENV_DIR="$PROJECT_ROOT/.venv-mps"
if [ ! -d "$VENV_DIR" ]; then
echo "Creating MPS-specific virtual environment..."
python3 -m venv "$VENV_DIR"
fi

# Set environment variables
export USE_GPU=true
export USE_ONNX=false
export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
export MODEL_DIR=src/models
export VOICES_DIR=src/voices/v1_0
export WEB_PLAYER_PATH=$PROJECT_ROOT/web

export DEVICE_TYPE=mps
# Enable MPS fallback for unsupported operations
export PYTORCH_ENABLE_MPS_FALLBACK=1

# Run FastAPI with GPU extras using uv run
uv pip install -e .
Owner (review comment): Script looks good

uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
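
Usage note: run the script from the repository root, e.g. chmod +x start-gpu_mac.sh && ./start-gpu_mac.sh; it installs the package and serves the API on port 8880.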