From 9dbdafda51e3ba8fde8b4a2d467a47c6dc835fa9 Mon Sep 17 00:00:00 2001
From: Yuriy
Date: Fri, 30 Jun 2023 20:01:39 +0300
Subject: [PATCH] sd_images_v1

---
Review notes (the hunks below were revised after review, so the blob ids in
the `index` lines are those of the originally submitted files):
 - utils: uuidv4 hex-encodes its random bytes instead of joining raw Buffers.
 - index: replies are awaited inside try/catch, callback queries are answered,
   sessions are matched on the Telegram user id, the queue wait is shared by
   both commands and its slot is released in `finally`.
 - sd-node-api: the negative prompt is a single constant; an empty image list
   raises a clear error.
 - sd-node-client: restored the `Promise<Txt2ImgResponse>` return type and
   corrected the "DPM++ 2S a" sampler name.

 src/bot.ts                              |   6 +
 src/modules/sd-images/index.ts          | 286 ++++++++++++++++++++++++
 src/modules/sd-images/sd-node-api.ts    |  74 ++++++
 src/modules/sd-images/sd-node-client.ts | 138 ++++++++++++
 src/modules/sd-images/utils.ts          |  14 ++
 5 files changed, 518 insertions(+)
 create mode 100644 src/modules/sd-images/index.ts
 create mode 100644 src/modules/sd-images/sd-node-api.ts
 create mode 100644 src/modules/sd-images/sd-node-client.ts
 create mode 100644 src/modules/sd-images/utils.ts

diff --git a/src/bot.ts b/src/bot.ts
index b9762dc8..f896a4af 100644
--- a/src/bot.ts
+++ b/src/bot.ts
@@ -2,6 +2,7 @@ import express from "express";
 import {Bot, MemorySessionStorage, session} from "grammy";
 import config from './config'
 import {VoiceMemo} from "./modules/voice-memo";
+import {SDImagesBot} from "./modules/sd-images";
 import {BotContext, BotSessionData, OnMessageContext} from "./modules/types";
 import {QRCodeBot} from "./modules/qrcode/QRCodeBot";
 
@@ -18,12 +19,17 @@ bot.use(session({
 
 const voiceMemo = new VoiceMemo();
 const qrCodeBot = new QRCodeBot();
+const sdImagesBot = new SDImagesBot();
 
 const onMessage = async (ctx: OnMessageContext) => {
   if (qrCodeBot.isSupportedEvent(ctx)) {
     qrCodeBot.onEvent(ctx);
   }
 
+  if (sdImagesBot.isSupportedEvent(ctx)) {
+    return sdImagesBot.onEvent(ctx);
+  }
+
   if(voiceMemo.isSupportedEvent(ctx)) {
     voiceMemo.onEvent(ctx)
   }
diff --git a/src/modules/sd-images/index.ts b/src/modules/sd-images/index.ts
new file mode 100644
index 00000000..22646164
--- /dev/null
+++ b/src/modules/sd-images/index.ts
@@ -0,0 +1,286 @@
+import { SDNodeApi } from "./sd-node-api";
+import config from "../../config";
+import { InlineKeyboard, InputFile } from "grammy";
+import { OnMessageContext } from "../types";
+import { sleep, uuidv4 } from "./utils";
+
+/** Bot commands handled by this module. */
+enum SupportedCommands {
+  IMAGE = 'image',
+  IMAGES = 'images',
+}
+
+/** Stages of an `/images` selection session. */
+enum SESSION_STEP {
+  IMAGE_SELECT = 'IMAGE_SELECT',
+  IMAGE_GENERATED = 'IMAGE_GENERATED',
+}
+
+/** State kept between an `/images` preview batch and the button press that follows it. */
+interface ISession {
+  id: string;
+  author: string;
+  /** Telegram user id of the requester; used to check who pressed a button. */
+  authorId: number;
+  step: SESSION_STEP;
+  prompt: string;
+  all_seeds: string[];
+}
+
+/** Display name and Telegram id of the user behind an update. */
+interface IAuthor {
+  id: number;
+  tag: string;
+}
+
+/** Milliseconds a queued request sleeps for each request ahead of it. */
+const QUEUE_WAIT_MS = 3000;
+
+/**
+ * Telegram front end for Stable Diffusion: `/image` replies with one picture,
+ * `/images` replies with four previews and a keyboard to re-render one of them.
+ */
+export class SDImagesBot {
+  sdNodeApi: SDNodeApi;
+
+  // Ids of requests waiting for (index > 0) or using (index 0) the generator.
+  private queue: string[] = [];
+  private sessions: ISession[] = [];
+
+  // Callback-data strings of every keyboard button sent so far.
+  callbackQuerys: string[] = [];
+
+  constructor() {
+    this.sdNodeApi = new SDNodeApi({ apiUrl: config.stableDiffusionHost });
+  }
+
+  /** True when `ctx` is one of the supported commands or a known button press. */
+  public isSupportedEvent(ctx: OnMessageContext): boolean {
+    const hasCommand = ctx.hasCommand(Object.values(SupportedCommands));
+
+    const hasCallbackQuery = ctx.hasCallbackQuery(this.callbackQuerys);
+
+    return hasCallbackQuery || hasCommand;
+  }
+
+  /**
+   * Dispatches `ctx` to the matching handler.
+   *
+   * Handlers are started without being awaited: each one reports its own
+   * errors, and awaiting would keep the caller waiting for the generation.
+   */
+  public async onEvent(ctx: OnMessageContext) {
+    if (!this.isSupportedEvent(ctx)) {
+      console.log(`### unsupported command ${ctx.message?.text}`);
+      return false;
+    }
+
+    if (ctx.hasCommand(SupportedCommands.IMAGE)) {
+      void this.onImageCmd(ctx);
+      return;
+    }
+
+    if (ctx.hasCommand(SupportedCommands.IMAGES)) {
+      void this.onImagesCmd(ctx);
+      return;
+    }
+
+    if (ctx.hasCallbackQuery(this.callbackQuerys)) {
+      void this.onImgSelected(ctx);
+      return;
+    }
+
+    console.log(`### unsupported command`);
+    await ctx.reply('### unsupported command');
+  }
+
+  /** Handles `/image <prompt>`: waits for its turn, then replies with one picture. */
+  onImageCmd = async (ctx: OnMessageContext) => {
+    try {
+      const prompt = this.readPrompt(ctx);
+      const author = await this.resolveAuthor(ctx);
+
+      if (!prompt) {
+        await ctx.reply(`${author.tag} please add prompt to your message`);
+        return;
+      }
+
+      const imageBuffer = await this.runQueued(ctx, author.tag, async () => {
+        await ctx.reply(`${author.tag} starting to generate your image`);
+
+        return this.sdNodeApi.generateImage(prompt);
+      });
+
+      await ctx.replyWithPhoto(new InputFile(imageBuffer));
+    } catch (e: unknown) {
+      await this.replyWithError(ctx, e);
+    }
+  }
+
+  /**
+   * Handles `/images <prompt>`: waits for its turn, sends four previews and a
+   * keyboard whose buttons are later handled by `onImgSelected`.
+   */
+  onImagesCmd = async (ctx: OnMessageContext) => {
+    try {
+      const prompt = this.readPrompt(ctx);
+      const author = await this.resolveAuthor(ctx);
+
+      if (!prompt) {
+        await ctx.reply(`${author.tag} please add prompt to your message`);
+        return;
+      }
+
+      const res = await this.runQueued(ctx, author.tag, async () => {
+        await ctx.reply(`${author.tag} starting to generate your images`);
+
+        return this.sdNodeApi.generateImagesPreviews(prompt);
+      });
+
+      const { all_seeds } = JSON.parse(res.info);
+
+      if (!Array.isArray(all_seeds)) {
+        throw new Error('api returned no seeds');
+      }
+
+      const newSession: ISession = {
+        id: uuidv4(),
+        author: author.tag,
+        authorId: author.id,
+        prompt,
+        step: SESSION_STEP.IMAGE_SELECT,
+        all_seeds,
+      };
+
+      this.sessions.push(newSession);
+
+      // Registered before the keyboard goes out so that even an immediate
+      // press is recognised by isSupportedEvent.
+      [1, 2, 3, 4].forEach(
+        key => this.callbackQuerys.push(`${newSession.id}_${key}`)
+      );
+
+      // Awaited so the previews arrive before the keyboard that refers to them.
+      await ctx.replyWithMediaGroup(
+        res.images.map((img, idx) => ({
+          type: "photo",
+          media: new InputFile(Buffer.from(img, 'base64')),
+          caption: String(idx + 1),
+        }))
+      );
+
+      await ctx.reply("Please choose 1 of 4 images for next high quality generation", {
+        parse_mode: "HTML",
+        reply_markup: new InlineKeyboard()
+          .text("1", `${newSession.id}_1`)
+          .text("2", `${newSession.id}_2`)
+          .text("3", `${newSession.id}_3`)
+          .text("4", `${newSession.id}_4`)
+          .row()
+      });
+    } catch (e: unknown) {
+      await this.replyWithError(ctx, e);
+    }
+  }
+
+  /**
+   * Handles a press on one of the preview buttons: re-renders the chosen
+   * preview with the full-quality preset, for the session's author only.
+   */
+  async onImgSelected(ctx: OnMessageContext): Promise<void> {
+    try {
+      const data = ctx.callbackQuery?.data;
+
+      if (!data) {
+        console.log('wrong callbackQuery')
+        return;
+      }
+
+      // Telegram clients show a progress indicator on the button until the
+      // query is answered, so acknowledge it before doing any slow work.
+      await ctx.answerCallbackQuery();
+
+      const author = await this.resolveAuthor(ctx);
+
+      const [sessionId, imageNumber] = data.split('_');
+
+      if (!sessionId || !imageNumber) {
+        return;
+      }
+
+      const session = this.sessions.find(s => s.id === sessionId);
+
+      if (!session || session.authorId !== author.id) {
+        return;
+      }
+
+      const seed = Number(session.all_seeds[Number(imageNumber) - 1]);
+
+      // An out-of-range button index yields NaN, which must not reach the API.
+      if (!Number.isFinite(seed)) {
+        return;
+      }
+
+      await ctx.reply(`${author.tag} starting to generate your image ${imageNumber} in high quality`);
+
+      const imageBuffer = await this.sdNodeApi.generateImageFull(session.prompt, seed);
+
+      await ctx.replyWithPhoto(new InputFile(imageBuffer));
+    } catch (e: unknown) {
+      await this.replyWithError(ctx, e);
+    }
+  }
+
+  /** Returns the text that follows the command, or '' when there is none. */
+  private readPrompt(ctx: OnMessageContext): string {
+    return typeof ctx.match === 'string' ? ctx.match.trim() : '';
+  }
+
+  /** Looks up who sent `ctx`; users without a username are shown by first name. */
+  private async resolveAuthor(ctx: OnMessageContext): Promise<IAuthor> {
+    const { user } = await ctx.getAuthor();
+
+    return {
+      id: user.id,
+      tag: user.username ? `@${user.username}` : user.first_name,
+    };
+  }
+
+  /**
+   * Runs `job` once every request queued earlier has finished, telling
+   * `author` their position while they wait. The queue slot is released
+   * when `job` settles or when waiting fails.
+   */
+  private async runQueued<T>(ctx: OnMessageContext, author: string, job: () => Promise<T>): Promise<T> {
+    const ticket = uuidv4();
+
+    this.queue.push(ticket);
+
+    try {
+      let idx = this.queue.indexOf(ticket);
+
+      while (idx !== 0) {
+        await ctx.reply(`${author} you are the ${idx + 1}/${this.queue.length}. Please wait about ${idx * QUEUE_WAIT_MS / 1000} sec`);
+
+        await sleep(QUEUE_WAIT_MS * idx);
+
+        idx = this.queue.indexOf(ticket);
+      }
+
+      return await job();
+    } finally {
+      this.queue = this.queue.filter(v => v !== ticket);
+    }
+  }
+
+  /** Logs `e` and tells the user the request failed; never rejects. */
+  private async replyWithError(ctx: OnMessageContext, e: unknown): Promise<void> {
+    console.log(e);
+
+    try {
+      await ctx.reply(`Error: something went wrong...`);
+    } catch (replyError: unknown) {
+      console.log(replyError);
+    }
+  }
+}
diff --git a/src/modules/sd-images/sd-node-api.ts b/src/modules/sd-images/sd-node-api.ts
new file mode 100644
index 00000000..01dc7c83
--- /dev/null
+++ b/src/modules/sd-images/sd-node-api.ts
@@ -0,0 +1,74 @@
+// import sdwebui, { Client, SamplingMethod } from 'node-sd-webui'
+import { Client, SamplingMethod } from './sd-node-client'
+
+/** Negative prompt sent with every request. */
+const NEGATIVE_PROMPT = '(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation';
+
+/** Decodes the first base64 image of a txt2img response. */
+const firstImage = (images: string[]): Buffer => {
+  const [image] = images;
+
+  if (!image) {
+    throw new Error('api returned no images');
+  }
+
+  return Buffer.from(image, 'base64');
+}
+
+/** Generation presets used by the bot on top of the txt2img client. */
+export class SDNodeApi {
+  client: Client;
+
+  constructor({ apiUrl }: { apiUrl: string }) {
+    this.client = new Client({ apiUrl })
+  }
+
+  /** Renders one 512x512 image for `prompt`; no seed is sent with the request. */
+  generateImage = async (prompt: string) => {
+    const { images } = await this.client.txt2img({
+      prompt,
+      negativePrompt: NEGATIVE_PROMPT,
+      samplingMethod: SamplingMethod.DPMPlusPlus_2M_Karras,
+      width: 512,
+      height: 512,
+      steps: 25,
+      batchSize: 1,
+    })
+
+    return firstImage(images);
+  }
+
+  /** Renders one 512x512 image for `prompt` with the given `seed`. */
+  generateImageFull = async (prompt: string, seed: number) => {
+    const { images } = await this.client.txt2img({
+      prompt,
+      negativePrompt: NEGATIVE_PROMPT,
+      samplingMethod: SamplingMethod.DPMPlusPlus_2M_Karras,
+      width: 512,
+      height: 512,
+      steps: 25,
+      batchSize: 1,
+      cfgScale: 7,
+      seed
+    })
+
+    return firstImage(images);
+  }
+
+  /**
+   * Renders a batch of four 512x512 previews for `prompt` with fewer steps
+   * than the single-image presets, and returns the raw response.
+   */
+  generateImagesPreviews = async (prompt: string) => {
+    return this.client.txt2img({
+      prompt,
+      negativePrompt: NEGATIVE_PROMPT,
+      samplingMethod: SamplingMethod.DPMPlusPlus_2M_Karras,
+      width: 512,
+      height: 512,
+      steps: 15,
+      batchSize: 4,
+      cfgScale: 10,
+    })
+  }
+}
\ No newline at end of file
diff --git a/src/modules/sd-images/sd-node-client.ts b/src/modules/sd-images/sd-node-client.ts
new file mode 100644
index 00000000..e3ab9e5b
--- /dev/null
+++ b/src/modules/sd-images/sd-node-client.ts
@@ -0,0 +1,138 @@
+import axios from "axios"
+
+/** Options accepted by `Client.txt2img`, in camelCase. */
+export type Txt2ImgOptions = {
+  hires?: {
+    steps: number
+    denoisingStrength: number
+    upscaler: string
+    upscaleBy?: number
+    resizeWidthTo?: number
+    resizeHeigthTo?: number
+  }
+  prompt: string
+  negativePrompt?: string
+  width?: number
+  height?: number
+  samplingMethod?: string
+  seed?: number
+  variationSeed?: number
+  variationSeedStrength?: number
+  resizeSeedFromHeight?: number
+  resizeSeedFromWidth?: number
+  batchSize?: number
+  batchCount?: number
+  steps?: number
+  cfgScale?: number
+  restoreFaces?: boolean
+  script?: {
+    name: string
+    args?: string[]
+  }
+}
+
+/** Body of a successful txt2img response. */
+export type Txt2ImgResponse = {
+  images: string[]
+  parameters: object
+  info: string
+}
+
+/** Translates camelCase options into the snake_case request body. */
+const mapTxt2ImgOptions = (options: Txt2ImgOptions) => {
+  let body: Record<string, unknown> = {
+    prompt: options.prompt,
+    negative_prompt: options.negativePrompt,
+    seed: options.seed,
+    subseed: options.variationSeed,
+    subseed_strength: options.variationSeedStrength,
+    sampler_name: options.samplingMethod,
+    batch_size: options.batchSize,
+    n_iter: options.batchCount,
+    steps: options.steps,
+    width: options.width,
+    height: options.height,
+    cfg_scale: options.cfgScale,
+    seed_resize_from_w: options.resizeSeedFromWidth,
+    seed_resize_from_h: options.resizeSeedFromHeight,
+    restore_faces: options.restoreFaces,
+  }
+
+  if (options.hires) {
+    body = {
+      ...body,
+      enable_hr: true,
+      denoising_strength: options.hires.denoisingStrength,
+      hr_upscaler: options.hires.upscaler,
+      hr_scale: options.hires.upscaleBy,
+      hr_resize_x: options.hires.resizeWidthTo,
+      hr_resize_y: options.hires.resizeHeigthTo,
+      hr_second_pass_steps: options.hires.steps,
+    }
+  }
+
+  if (options.script) {
+    body = {
+      ...body,
+      script_name: options.script.name,
+      script_args: options.script.args || [],
+    }
+  }
+
+  return body
+}
+
+/** Sampler names, sent to the API as `sampler_name`. */
+export const SamplingMethod = {
+  Euler_A: "Euler a",
+  Euler: "Euler",
+  LMS: "LMS",
+  Heun: "Heun",
+  DPM2: "DPM2",
+  DPM2_A: "DPM2 a",
+  // The key keeps its original spelling; the value is "2S", as in the Karras variant below.
+  DPMPlusPlus_S2_A: "DPM++ 2S a",
+  DPMPlusPlus_2M: "DPM++ 2M",
+  DPMPlusPlus_SDE: "DPM++ SDE",
+  DPM_Fast: "DPM fast",
+  DPM_Adaptive: "DPM adaptive",
+  LMS_Karras: "LMS Karras",
+  DPM2_Karras: "DPM2 Karras",
+  DPM2_A_Karras: "DPM2 a Karras",
+  DPMPlusPlus_2S_A_Karras: "DPM++ 2S a Karras",
+  DPMPlusPlus_2M_Karras: "DPM++ 2M Karras",
+  DPMPlusPlus_SDE_Karras: "DPM++ SDE Karras",
+  DDIM: "DDIM",
+  PLMS: "PLMS"
+}
+
+/** Minimal HTTP client for the txt2img endpoint. */
+export class Client {
+  apiUrl: string;
+
+  constructor({ apiUrl }: { apiUrl: string }) {
+    this.apiUrl = apiUrl;
+  }
+
+  /**
+   * Posts `options` to the txt2img endpoint and returns the response body.
+   *
+   * @throws Error when the response carries no `images` array; errors raised
+   * by axios propagate unchanged.
+   */
+  txt2img = async (options: Txt2ImgOptions): Promise<Txt2ImgResponse> => {
+    const body = mapTxt2ImgOptions(options)
+
+    const endpoint = '/sdapi/v1/txt2img';
+
+    const response = await axios.post(`${this.apiUrl}${endpoint}`, body)
+
+    const data: Partial<Txt2ImgResponse> | null | undefined = response.data;
+
+    if (!Array.isArray(data?.images)) {
+      throw new Error('api returned an invalid response')
+    }
+
+    return data as Txt2ImgResponse
+  }
+}
\ No newline at end of file
diff --git a/src/modules/sd-images/utils.ts b/src/modules/sd-images/utils.ts
new file mode 100644
index 00000000..e74e8556
--- /dev/null
+++ b/src/modules/sd-images/utils.ts
@@ -0,0 +1,14 @@
+import { randomBytes } from 'crypto'
+
+/**
+ * Returns a random id made of four dash-separated groups of eight hex digits.
+ * The id contains no `_`, so it can be embedded in data that is split on `_`.
+ */
+export const uuidv4 = () => {
+  // Each group is hex-encoded explicitly: joining the raw Buffers would
+  // decode random bytes as UTF-8 and could yield any character, `_` included.
+  return Array.from({ length: 4 }, () => randomBytes(4).toString('hex')).join('-');
+};
+
+/** Resolves after `ms` milliseconds. */
+export const sleep = (ms: number) => new Promise(res => setTimeout(res, ms));
\ No newline at end of file