import os

import numpy as np
import torch
from PIL import Image
from diffusers import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

import folder_paths
from ComfyUI_RH_OminiControl.src.generate import generate, seed_everything
from ComfyUI_RH_OminiControl.src.condition import Condition
# The star import supplies g_width, g_height, release_gpu, encode_condition
# and decode_latents used below.
from ComfyUI_RH_OminiControl.rh_utils import *
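
# ComfyUI helper for OminiControl subject-driven generation on FLUX.1-schnell.
# To keep peak VRAM low, the work is split into separate passes (condition
# encoding, text encoding, LoRA-conditioned denoising, latent decoding), each
# loading only the sub-models it needs and calling release_gpu() in between.
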
def run(t_img, prompt, seed):
    # Expects a single ComfyUI IMAGE tensor of shape (1, H, W, C) in [0, 1].
    assert t_img.shape[0] == 1
    i = 255.0 * t_img[0].numpy()
    image = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)).convert("RGB").resize((g_width, g_height))
    release_gpu()
    # Resolve local model paths under the ComfyUI models directory.
    flux_dir = os.path.join(folder_paths.models_dir, 'flux', 'FLUX.1-schnell')
    lora_model = os.path.join(folder_paths.models_dir, 'flux', 'OminiControl', 'omini', 'subject_512.safetensors')

    # Pre-encode the condition image before any text models are loaded.
    encoded_condition = encode_condition(flux_dir, image)
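
    # Text-encoding pass: build the pipeline with only the CLIP/T5 text stack
    # (transformer and VAE left as None) so the prompt embeddings can be
    # pre-computed and the encoders freed before denoising.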
    text_encoder = CLIPTextModel.from_pretrained(
        flux_dir, subfolder="text_encoder", torch_dtype=torch.bfloat16
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        flux_dir, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
    )
    tokenizer = CLIPTokenizer.from_pretrained(flux_dir, subfolder="tokenizer")
    tokenizer_2 = T5TokenizerFast.from_pretrained(flux_dir, subfolder="tokenizer_2")
    pipeline = FluxPipeline.from_pretrained(
        flux_dir,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        transformer=None,
        vae=None,
    ).to("cuda")
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=256
        )

    # Free the entire text stack before the transformer is loaded.
    del text_encoder
    del text_encoder_2
    del tokenizer
    del tokenizer_2
    del pipeline
    release_gpu()
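
    # Denoising pass: reload the pipeline with only the transformer (text
    # encoders and VAE stay unloaded) and attach the OminiControl subject LoRA.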
    pipeline = FluxPipeline.from_pretrained(
        flux_dir,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        vae=None,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to('cuda')
    pipeline.load_lora_weights(
        lora_model,
        adapter_name="subject",
    )
    condition = Condition("subject", image)
    # Note '**', not '^': '^' is bitwise XOR in Python, so '2 ^ 16' would be 18.
    seed_everything(int(seed) % (2 ** 16))
    result_latents = generate(
        pipeline,
        encoded_condition=encoded_condition,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        text_ids=text_ids,
        conditions=[condition],
        output_type="latent",
        return_dict=False,
        num_inference_steps=8,
        height=g_height,
        width=g_width,
    )
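
    # Decoding pass: drop the transformer, then decode the latents with the
    # VAE via rh_utils.decode_latents.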
    del pipeline
    release_gpu()
    result_img = decode_latents(flux_dir, result_latents[0]).images[0]
    # Back to a ComfyUI IMAGE tensor: (1, H, W, C), float32 in [0, 1].
    return torch.from_numpy(np.array(result_img).astype(np.float32) / 255.0).unsqueeze(0)