Skip to content

Commit 31bf722

Browse files
author
Yi Zhang
committed
Initial commit
0 parents  commit 31bf722

File tree

145 files changed

+3486936
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

145 files changed

+3486936
-0
lines changed

1.py

+178
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
2+
3+
import torch
4+
import os
5+
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
6+
device = "cuda" if torch.cuda.is_available() else "cpu"
7+
import logging
8+
9+
from scipy.stats import mannwhitneyu
10+
from diffusers import StableDiffusionPipeline
11+
from torchmetrics.multimodal import CLIPScore
12+
import torchvision.transforms as transforms
13+
import Levenshtein
14+
import numpy as np
15+
import time
16+
import pandas as pd
17+
torch.cuda.empty_cache()
18+
from config import (
19+
e_threshold, origin_prompt_path, sigma, data_path,
20+
num_inference_steps, num_batch, batch_size,
21+
model_id, ori_prompt, stop_early
22+
)
23+
24+
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
25+
pipe = pipe.to(device)
26+
27+
transform = transforms.Compose([transforms.ToTensor()])
28+
metric = CLIPScore().to(device)
29+
def calculate_text_image_distance(text, image):
30+
# print(transform(image).shape) # torch.Size([1, 512, 512])
31+
img = transform(image)*255
32+
score = metric(img.to(device), text)
33+
return score.detach().cpu().numpy().item()
34+
35+
def setup_logger(file_name):
36+
logger = logging.getLogger(file_name)
37+
logger.setLevel(logging.INFO)
38+
39+
handler = logging.FileHandler(file_name)
40+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S')
41+
handler.setFormatter(formatter)
42+
43+
logger.addHandler(handler)
44+
return logger
45+
46+
def cal_loss(ori_loss, disturb_prompt, ori_prompt):
47+
global stop_early
48+
stop_early = 0
49+
print("dis_prompt", disturb_prompt)
50+
print("ori_prompt", ori_prompt)
51+
logger.info(f"dis_prompt: {disturb_prompt}")
52+
logger.info(f"ori_prompt: {ori_prompt}")
53+
logger.info(f"@" * 20)
54+
dis_interim_loss = []
55+
for i in range(num_batch): # 5
56+
ori_interim_loss = ori_loss[0: batch_size*(i+1)]
57+
generator = torch.Generator(device).manual_seed(1023+i)
58+
images = pipe([disturb_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
59+
for j in range(batch_size): # 5
60+
dis_interim_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
61+
62+
_, p_val = mannwhitneyu(dis_interim_loss, ori_interim_loss)
63+
logger.info(f"p_val: {p_val}")
64+
if p_val <= 0.0158:
65+
if i <= 3:
66+
stop_early = 1
67+
logger.info(f"Reject interim num : {i+1}")
68+
return 0, dis_interim_loss
69+
logger.info(f"@" * 20)
70+
logger.info(f"Accept interim num : {num_batch}")
71+
return 1, dis_interim_loss
72+
73+
74+
def get_AE(sample_data):
75+
import random
76+
random.seed(42)
77+
strings = [line.split(':')[0].strip() for line in sample_data[1:]]
78+
sampled_strings = random.sample(strings, len(strings))
79+
return sampled_strings
80+
81+
def calculate_R(E_n, n):
82+
import math
83+
robust, epsilon = 0, 0
84+
epsilon = math.sqrt( (0.6 * math.log(math.log(n, 1.1) + 1, 10) + (1.8 ** -1) * math.log(24/sigma, 10)) / n )
85+
robust = (E_n - epsilon) / n
86+
return robust, epsilon
87+
88+
def get_origin_prompt(origin_prompt_path):
89+
origin_prompt = {}
90+
i = 1
91+
with open(origin_prompt_path,'r') as file:
92+
for line in file:
93+
origin_prompt[i] = line.strip()
94+
i += 1
95+
return origin_prompt
96+
97+
if __name__ == "__main__":
98+
start_time = time.time()
99+
100+
origin_prompts = get_origin_prompt(origin_prompt_path)
101+
102+
for index, ori_prompt in origin_prompts.items():
103+
logger = setup_logger(f"adaptive_log/log_{index}.log")
104+
logger.info(f"num_inference_steps: {num_inference_steps}")
105+
logger.info(f"num_batch: {num_batch}")
106+
logger.info(f"batch_size: {batch_size}")
107+
df = pd.read_csv(f"./generate_AE/char_AE/result_{index}.csv")
108+
logger.info(f"./generate_AE/char_AE/result_{index}.csv")
109+
efficient_n = 0
110+
if index <=3:
111+
pass
112+
else:
113+
logger.info(f"ori_prompt: {ori_prompt}")
114+
ori_loss = []
115+
for i in range(num_batch):
116+
generator = torch.Generator(device).manual_seed(1023+i)
117+
images = pipe([ori_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
118+
for j in range(batch_size):
119+
ori_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
120+
logger.info(f"ori_loss: {len(ori_loss)} {ori_loss}")
121+
logger.info(f"*" * 120)
122+
for id in range(1, 5):
123+
efficient_n = 0
124+
E_n, n = 0, 0
125+
L_distance, AdvSt2i = [], []
126+
robust_re, epsilon_re = [], []
127+
prompt_2 = [f"{index}A mysterious and magical forest with glowing mushrooms and hidden creatures.", "Serene meadow with wildflowers swaying in the gentle breeze under a clear blue sky."]
128+
logger.info(f"disturb rate: {id}")
129+
# logger.info(f"disturb_num: {sample_data[0]}")
130+
# prompt_2 = get_AE(sample_data)
131+
for i, disturb_prompt in enumerate(prompt_2):
132+
n = i + 1
133+
L_distance.append(Levenshtein.distance(ori_prompt, disturb_prompt))
134+
whether_robust, dis_loss = cal_loss(ori_loss, disturb_prompt, ori_prompt)
135+
if whether_robust:
136+
E_n += 1
137+
if stop_early:
138+
efficient_n += 1
139+
AdvSt2i.append(sum(dis_loss) / len(dis_loss))
140+
robust, epsilon = calculate_R(E_n, n)
141+
robust_re.append(robust)
142+
epsilon_re.append(epsilon)
143+
logger.info(f"stop_early: {efficient_n}")
144+
logger.info(f"E_n: {E_n}")
145+
logger.info(f"n: {n}")
146+
logger.info(f"robust reach: {robust}")
147+
logger.info(f"epsilon reach: {epsilon}")
148+
if epsilon <= e_threshold: # stop condition
149+
break
150+
print("*" * 120)
151+
logger.info(f"*" * 120)
152+
print("*" * 120)
153+
logger.info(f"*" * 120)
154+
logger.info(f"robust = {robust_re}")
155+
logger.info(f"epsilon = {epsilon_re}")
156+
logger.info(f"stop_early: {efficient_n}")
157+
logger.info(f"E_n = {E_n}")
158+
logger.info(f"n = {n}")
159+
logger.info(f"AdvSt2i = {round(np.mean(AdvSt2i), 2)}")
160+
logger.info(f"OriSt2i = {round(np.mean(ori_loss), 2)}")
161+
logger.info(f"Levenshtein = {round(np.mean(L_distance), 2)}")
162+
logger.info(f"robust = {robust}")
163+
logger.info(f"epsilon = {epsilon}")
164+
165+
166+
167+
end_time = time.time()
168+
elapsed_time = end_time - start_time
169+
hours, remainder = divmod(elapsed_time, 3600)
170+
minutes, seconds = divmod(remainder, 60)
171+
print(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
172+
logger.info(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
173+
logger.info(f"&" * 150)
174+
175+
176+
177+
178+

__pycache__/config.cpython-311.pyc

612 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)