
lora_urls #160

Closed
fagemx opened this issue Jan 8, 2024 · 5 comments
Labels
enhancement New feature or request

Comments


fagemx commented Jan 8, 2024

Have you considered adding (or spinning off a separate version with) support for Replicate lora_urls?
That would be very convenient! Thanks!

konieshadow (Collaborator) commented

> Have you considered adding (or spinning off a separate version with) support for Replicate lora_urls? That would be very convenient! Thanks!

I don't quite understand your use case. Do you mean adding custom LoRA support to the model on replicate.com?


fagemx commented Jan 8, 2024

Yes. Training on replicate.com produces lora_urls.
I'd like to be able to use LoRAs on Replicate as well.

TechnikMax commented

Hello, I'm also looking into an option to use custom LoRAs with the cog version / on Replicate. The use case is generating images that include special objects or fantasy figures, which can be achieved through LoRAs.

I already experimented a bit with the code, but (on an RTX 4060) generation never finished when I ran the cog predict command, even after several minutes. So I wasn't able to fully test my code, but it does correctly download the file into the folder and add it to the ImageGenerationParams. I'm also not able to push my feature branch, so I will post the code here. It would be great if @konieshadow could implement the code and update the version on Replicate so we can use custom LoRAs. I'm happy to assist if any additional changes are required or if there are any open questions.

lora_manager.py

import hashlib
import os
import requests


def _hash_url(url):
    """Generate a stable hash for a URL, used as the cache file name."""
    return hashlib.md5(url.encode('utf-8')).hexdigest()


class LoraManager:
    def __init__(self):
        self.cache_dir = "/models/loras/"
        # Ensure the cache directory exists before any download attempt.
        os.makedirs(self.cache_dir, exist_ok=True)

    def _download_lora(self, url):
        """Download a LoRA from a URL into the cache, skipping files already present."""
        file_name = f"{_hash_url(url)}.safetensors"
        filepath = os.path.join(self.cache_dir, file_name)

        if not os.path.exists(filepath):
            print(f"Starting download for: {url}")

            try:
                response = requests.get(url, timeout=10, stream=True)
                response.raise_for_status()
                with open(filepath, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
                print(f"Download successful, saved as {file_name}")

            except Exception as e:
                raise RuntimeError(f"Error downloading {url}: {e}") from e

        else:
            print(f"LoRA already downloaded: {url}")
        return file_name

    def check(self, urls):
        """Manages the specified LoRAs: downloads missing ones and returns their file names."""
        paths = []
        for url in urls:
            path = self._download_lora(url)
            paths.append(path)
        return paths
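For reference, here is a minimal, self-contained sketch of how the cache file name above is derived (md5 of the UTF-8 encoded URL, as in _hash_url); cache_filename is a hypothetical helper for illustration only, not part of the code above:

```python
import hashlib
import os


def cache_filename(url, cache_dir="/models/loras/"):
    # Hypothetical helper: mirrors _hash_url plus the .safetensors naming
    # used by LoraManager._download_lora.
    url_hash = hashlib.md5(url.encode("utf-8")).hexdigest()
    return os.path.join(cache_dir, f"{url_hash}.safetensors")


# The same URL always maps to the same cached file, so re-runs skip the download.
name = cache_filename("https://example.com/my-lora.safetensors")
```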

Changes in predict.py

from lora_manager import LoraManager

Add to the predict arguments:

use_default_loras: bool = Input(default=True, description="Use default LoRAs"),
        loras_custom_urls: str = Input(default="",
                                       description="Custom LoRA URLs in the format 'url,weight'; provide multiple separated by ';' (example: 'url1,0.3;url2,0.1')"),

Then replace loras = copy.copy(default_loras) with:

lora_manager = LoraManager()

        # Use default loras if selected
        loras = copy.copy(default_loras) if use_default_loras else []

        # add custom user loras if provided
        if loras_custom_urls:
            urls = [url.strip() for url in loras_custom_urls.split(';')]

            loras_with_weights = [url.split(',') for url in urls]

            total_loras_count = len(loras) + len(loras_with_weights)
            if total_loras_count > 4:
                raise ValueError("The total number of LoRAs (default and custom) cannot exceed 4")

            custom_lora_paths = lora_manager.check([lw[0] for lw in loras_with_weights])
            custom_loras = [[path, float(lw[1]) if len(lw) > 1 else 1.0] for path, lw in
                            zip(custom_lora_paths, loras_with_weights)]

            loras.extend(custom_loras)
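The 'url,weight' parsing above can be sketched as a standalone function for clarity (parse_custom_loras is a hypothetical name for illustration, not part of the actual predict.py; unlike the snippet above, this sketch also tolerates empty entries):

```python
def parse_custom_loras(loras_custom_urls):
    """Parse 'url1,0.3;url2,0.1' into (url, weight) pairs; weight defaults to 1.0."""
    pairs = []
    for entry in loras_custom_urls.split(";"):
        entry = entry.strip()
        if not entry:
            continue  # tolerate trailing or doubled separators
        parts = entry.split(",")
        url = parts[0].strip()
        weight = float(parts[1]) if len(parts) > 1 else 1.0
        pairs.append((url, weight))
    return pairs
```

Note that the predict.py snippet also enforces a cap of four LoRAs in total (default plus custom).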

konieshadow (Collaborator) commented

@TechnikMax Thank you for your code. I will implement the function with it soon.

officalunimoghd commented

@konieshadow Any news about the implementation?

@mrhan1993 mrhan1993 added the enhancement New feature or request label Mar 4, 2024
@mrhan1993 mrhan1993 mentioned this issue Apr 8, 2024