Skip to content

Commit fdf46cd

Browse files
author
Yi Zhang
committed
update
1 parent cddd817 commit fdf46cd

6 files changed

+749
-0
lines changed

config.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
# Configuration shared by the generation / defence scripts.

# origin_prompt_path = "./generate_AE/origin_prompts/chatGPT.txt"
origin_prompt_path = "./generate_AE/origin_prompts/coco.txt"  # source prompts, one per line


num_inference_steps = 15  # diffusion denoising steps passed to the SD pipeline
num_batch = 5             # number of interim batches per evaluation (group-sequential stages)
batch_size = 12           # images generated per batch
model_id = "runwayml/stable-diffusion-v1-5"  # HF checkpoint for StableDiffusionPipeline


# R_threshold = 0.94
e_threshold = 0.08  # presumably the target half-width for epsilon (only seen in a commented-out stop condition) — TODO confirm
sigma = 0.3         # confidence parameter in calculate_R's epsilon formula (log(24/sigma) term)
stop_early = 0      # NOTE(review): not in version_defence.py's config import; appears unused there

readme.md

Whitespace-only changes.

requirements.txt

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
torch==2.1.2
2+
scipy==1.11.4
3+
diffusers==0.25.0
4+
torchmetrics==1.2.1
5+
torchvision==0.16.2
6+
Levenshtein==0.23.0
7+
pandas==2.1.4
8+
transformers==4.36.2
9+

version_defence.py

+276
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
import os
2+
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
3+
import torch
4+
device = "cuda" if torch.cuda.is_available() else "cpu"
5+
import logging
6+
from scipy.stats import mannwhitneyu
7+
from diffusers import StableDiffusionPipeline
8+
from torchmetrics.multimodal import CLIPScore
9+
import torchvision.transforms as transforms
10+
import Levenshtein
11+
from scipy.stats import shapiro
12+
import numpy as np
13+
import time
14+
from scipy import stats
15+
import pandas as pd
16+
torch.cuda.empty_cache()
17+
from config import (
18+
e_threshold, origin_prompt_path, sigma,
19+
num_inference_steps, num_batch, batch_size,
20+
model_id
21+
)
22+
import random
23+
# random.seed(42)
24+
25+
def setup_logger(file_name):
    """Return a logger that appends INFO records to *file_name*.

    The logger is keyed by the file name, so calling this again with the
    same path returns the same logger object.  A guard prevents attaching
    a second FileHandler on repeated calls — the original version added a
    new handler every time, which duplicated every subsequent log line.
    """
    logger = logging.getLogger(file_name)
    logger.setLevel(logging.INFO)

    # logging.getLogger returns a process-wide singleton per name, so only
    # attach the file handler the first time this name is configured.
    if not logger.handlers:
        handler = logging.FileHandler(file_name)
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%m/%d/%Y %H:%M:%S',
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
35+
36+
# Load Stable Diffusion in half precision and move it to the selected device
# (GPU when available, per the module-level `device`).
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# pipe = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', torch_dtype=torch.float16)
pipe = pipe.to(device)

# PIL image -> float tensor in [0, 1]; calculate_text_image_distance rescales
# by 255 because torchmetrics' CLIPScore expects 0-255 pixel values.
transform = transforms.Compose([transforms.ToTensor()])
metric = CLIPScore().to(device)
42+
def calculate_text_image_distance(text, image):
    """Return the CLIP score between *text* and a PIL *image* as a Python float.

    The tensor produced by the module-level `transform` is in [0, 1], so it
    is rescaled to [0, 255] before being fed to the module-level CLIPScore
    metric on `device`.
    """
    pixels = transform(image) * 255  # e.g. shape (1, 512, 512) for a 512x512 grayscale render
    clip_score = metric(pixels.to(device), text)
    return clip_score.detach().cpu().numpy().item()
47+
48+
def calculate_u_zValue(data1, data2):
    """One-sided Mann-Whitney U test (H1: data1 stochastically greater).

    Returns (p_value, z_value), where the z-value standardises the U
    statistic by its null mean n1*n2/2 and standard deviation
    sqrt(n1*n2*(n1+n2+1)/12) — no tie or continuity correction applied.
    """
    u_stat, p_val = mannwhitneyu(data1, data2, alternative='greater')
    n1 = len(data1)
    n2 = len(data2)
    expected_u = n1 * n2 / 2
    sd_u = np.sqrt(n1 * n2 * (n1 + n2 + 1) / 12)
    return p_val, (u_stat - expected_u) / sd_u
55+
56+
def calculate_t_zValue(data1, data2):
    """Two-sided independent-samples t-test plus a Welch-style z statistic.

    Returns (p_value, z_value); the p-value comes from scipy's pooled
    ttest_ind, while z is the difference of sample means divided by the
    combined standard error using sample variances (ddof=1).
    """
    _, p_val = stats.ttest_ind(data1, data2)
    n1, n2 = len(data1), len(data2)
    combined_se = np.sqrt(np.var(data1, ddof=1) / n1 + np.var(data2, ddof=1) / n2)
    z_value = (np.mean(data1) - np.mean(data2)) / combined_se
    return p_val, z_value
63+
64+
65+
# def cal_loss(ori_loss, disturb_prompt, ori_prompt):
66+
# print("dis_prompt", disturb_prompt)
67+
# print("ori_prompt", ori_prompt)
68+
# logger.info(f"dis_prompt: {disturb_prompt}")
69+
# logger.info(f"ori_prompt: {ori_prompt}")
70+
# logger.info(f"--" * 20)
71+
# whether_robust, stop_early, futility = 1, 0, 0
72+
# dis_interim_loss = []
73+
# for i in range(num_batch): # 5
74+
# ori_interim_loss = ori_loss[0: batch_size*(i+1)]
75+
# generator = torch.Generator(device).manual_seed(1023+i)
76+
# images = pipe([disturb_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
77+
# for j in range(batch_size): # 12
78+
# dis_interim_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
79+
# # U-test
80+
# # _, p_val = mannwhitneyu(dis_interim_loss, ori_interim_loss, alternative = 'less')
81+
# logger.info(f"dis_interim_loss: {len(dis_interim_loss)}; {dis_interim_loss}")
82+
# logger.info(f"ori_interim_loss: {len(ori_interim_loss)}; {ori_interim_loss}")
83+
# logger.info(f"--" * 20)
84+
# return whether_robust
85+
86+
87+
def cal_loss(ori_loss, disturb_prompt, ori_prompt):
    """Group-sequential robustness test of a perturbed prompt.

    Generates up to `num_batch` batches of images for *disturb_prompt*,
    scores them against *ori_prompt* with CLIP, and at each interim stage
    compares the perturbed scores to the matching prefix of *ori_loss*
    (baseline scores for the original prompt).

    Returns 1 ("robust" / fail to reject) or 0 ("not robust" / reject).
    Uses the module-level `pipe`, `logger`, and config constants.
    """
    print("dis_prompt", disturb_prompt)
    print("ori_prompt", ori_prompt)
    logger.info(f"dis_prompt: {disturb_prompt}")
    logger.info(f"ori_prompt: {ori_prompt}")
    logger.info(f"--" * 20)
    dis_interim_loss = []
    # Stage-wise significance levels and boundaries for the 5 interim looks.
    alpha = [0.0148, 0.0262, 0.0354, 0.0432, 0.05]
    futility_boundary = [-0.145, 0.511, 1.027, 1.497, float('inf')]
    efficacy_boundary = [2.176, 2.144, 2.113, 2.090, 2.071]  # reject H0 when Z-score exceeds this critical value
    for i in range(num_batch):  # 5 interim stages
        ori_interim_loss = ori_loss[0: batch_size*(i+1)]
        # Fixed per-stage seed so perturbed generations are reproducible and
        # paired with the baseline batches.
        generator = torch.Generator(device).manual_seed(1023+i)
        images = pipe([disturb_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
        for j in range(batch_size):  # batch_size (12) images per batch
            dis_interim_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))

        logger.info(f"dis_interim_loss: {len(dis_interim_loss)}; {dis_interim_loss}")
        logger.info(f"ori_interim_loss: {len(ori_interim_loss)}; {ori_interim_loss}")
        # Normality check decides between the t-test and the Mann-Whitney U test.
        # NOTE(review): the slice uses a hard-coded 12 rather than batch_size —
        # equivalent only while batch_size == 12; confirm intent.
        _, p_1 = shapiro(ori_interim_loss[0: 12*(i+1)])
        _, p_2 = shapiro(dis_interim_loss[0: 12*(i+1)])
        if p_1 > 0.05 and p_2 > 0.05:  # both samples look normally distributed
            p_val, z_val = calculate_t_zValue(ori_interim_loss[0: 12*(i+1)], dis_interim_loss[0: 12*(i+1)])
        else:
            p_val, z_val = calculate_u_zValue(ori_interim_loss[0: 12*(i+1)], dis_interim_loss[0: 12*(i+1)])
        logger.info(f"p_val, z_val: {p_val} {z_val}")

        # Efficacy stop: evidence the perturbed prompt degrades the score.
        if z_val >= efficacy_boundary[i]:
            return 0

        # Futility stop: no realistic chance of rejecting H0 — declare robust.
        if z_val <= futility_boundary[i]:
            return 1

        # Final look: decide on the p-value against the last alpha level.
        if i == 4:
            if p_val > alpha[i]:
                return 1
            else:
                return 0
        logger.info(f"--" * 20)
    return 1
127+
128+
129+
def get_AE(sample_data):
    """Pick one adversarial-example prompt from *sample_data*.

    The first entry of sample_data is a header and is skipped; each
    remaining entry has the form "prompt: value", of which only the
    prompt part (before the first ':') is kept.  The global RNG is
    re-seeded with 42, so the draw is reproducible.  Returns a
    single-element list.
    """
    import random
    random.seed(42)
    candidates = [entry.split(':')[0].strip() for entry in sample_data[1:]]
    # sampled_strings = random.sample(strings, len(strings))
    return random.choices(candidates, k=1)
136+
137+
def calculate_R(E_n, n, sigma_val=None):
    """Anytime-valid confidence interval for the robustness ratio E_n / n.

    Parameters:
        E_n: number of "robust" outcomes observed so far.
        n: number of trials observed so far (must be >= 1).
        sigma_val: confidence parameter for the epsilon bound; defaults to
            the module-level `sigma` from config (backward compatible —
            existing two-argument callers are unaffected).

    Returns (robust_left, robust_right, epsilon) where epsilon is the
    half-width sqrt((0.6*log10(log_{1.1}(n)+1) + (1/1.8)*log10(24/sigma)) / n).
    """
    import math
    s = sigma if sigma_val is None else sigma_val
    epsilon = math.sqrt(
        (0.6 * math.log(math.log(n, 1.1) + 1, 10)
         + (1.8 ** -1) * math.log(24 / s, 10)) / n
    )
    center = E_n / n
    robust_left = center - epsilon
    robust_right = center + epsilon
    return robust_left, robust_right, epsilon
144+
145+
def get_origin_prompt(origin_prompt_path):
    """Read prompts from *origin_prompt_path*, one per line.

    Returns a dict mapping the 1-based line number to the stripped line
    text (the driver script keys prompts by their line index).
    """
    with open(origin_prompt_path, 'r') as fh:
        return {line_no: line.strip() for line_no, line in enumerate(fh, start=1)}
153+
154+
def defence_gramfomer(influent_sentence):
    """Grammar-correct a sentence with Gramformer.

    Returns the first candidate correction, or "" when the corrector
    yields no candidates.  The import is kept local because gramformer
    is an optional, heavyweight dependency.
    """
    from gramformer import Gramformer
    corrector = Gramformer(models=1, use_gpu=False)  # 1=corrector, 2=detector
    candidates = corrector.correct(influent_sentence, max_candidates=1)
    return next(iter(candidates), "")
161+
162+
# def defence_spellchecker(influent_sentence):
163+
# from spellchecker import SpellChecker
164+
# import pandas as pd
165+
# spell = SpellChecker()
166+
# str1 = influent_sentence
167+
# str2 = str1.split()
168+
# misspelled = spell.unknown(str1.split()) # {'qducks.'}
169+
# corrected_sentence = " ".join(spell.correction(word) if word in misspelled else word for word in str2)
170+
# return corrected_sentence + '.'
171+
def defence_spellchecker(influent_sentence):
    """Spell-correct a sentence with pyspellchecker.

    Words flagged as unknown are replaced by their best correction; known
    words pass through unchanged.  Tokens for which no correction exists
    (spell.correction returns None) are dropped, matching the original
    behaviour.  Returns the corrected sentence with a trailing '.'.

    Improvements over the original: spell.correction() — an expensive
    edit-distance search — is now called once per word instead of up to
    twice, and the unused pandas import was removed.
    """
    from spellchecker import SpellChecker
    spell = SpellChecker()
    words = influent_sentence.split()
    misspelled = spell.unknown(words)  # e.g. {'qducks.'}
    # NOTE(review): SpellChecker.unknown() lower-cases its results, so
    # capitalised misspellings may not match `word in misspelled` — this
    # mirrors the original logic; confirm whether that is intended.
    corrected = []
    for word in words:
        best = spell.correction(word)
        if best is None:
            continue  # no candidate at all: drop the token (as before)
        corrected.append(best if word in misspelled else word)
    return " ".join(corrected) + '.'
180+
181+
def defence_autocorrect(influent_sentence):
    """Spell-correct a sentence word-by-word with the autocorrect package.

    Each whitespace-separated token is passed through an English Speller
    and the results are re-joined with single spaces.
    """
    from autocorrect import Speller
    correct = Speller(lang='en')
    return ' '.join(correct(token) for token in influent_sentence.split())
187+
188+
189+
if __name__ == "__main__":
    # Robustness-evaluation driver: build a baseline pool of CLIP scores for
    # one original prompt, then repeatedly sample perturbed prompts (AEs),
    # pass them through a spell-checker defence, test each with the
    # group-sequential procedure in cal_loss, and track an anytime-valid
    # confidence interval (calculate_R) for the robustness ratio.
    start_time = time.time()
    origin_prompts = get_origin_prompt(origin_prompt_path)
    for index, ori_prompt in origin_prompts.items():
        if index == 19:  # only prompt #19 is evaluated in this run
            AEdata_path = f"./generate_AE/coco/char_AE/result_{index}.csv"
            logger = setup_logger(f"adaptive_log/coco_update/10_rate/defence/spellchecker/log_char_{index}.log")
            logger.info(f"sigma: {sigma}")
            logger.info(f"num_inference_steps: {num_inference_steps}")
            logger.info(f"num_batch: {num_batch}")
            logger.info(f"batch_size: {batch_size}")
            logger.info(AEdata_path)
            logger.info(f"ori_prompt: {ori_prompt}")
            df = pd.read_csv(AEdata_path)
            # Baseline: CLIP scores of the original prompt's own generations,
            # using the same per-batch seeds cal_loss uses for the perturbed runs.
            ori_loss = []
            for i in range(num_batch):
                generator = torch.Generator(device).manual_seed(1023+i)
                images = pipe([ori_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
                for j in range(batch_size):
                    ori_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
            logger.info(f"ori_loss: {len(ori_loss)} {ori_loss}")
            logger.info(f"*" * 120)
            for id in range(1, 2): # focus on 10% perturb rate
                efficient_n = 0
                Non_AE, n = 0, 0
                L_distance, AdvSt2i = [], []
                robust_re, epsilon_re = [], []
                # CSV column "Column {id}": first cell is the perturbation
                # count, remaining cells are "prompt: value" entries.
                sample_data = list(df[f"Column {id}"].dropna())
                strings = [line.split(':')[0].strip() for line in sample_data[1:]]
                logger.info(f"disturb rate: {id}")
                logger.info(f"disturb_num: {sample_data[0]}")
                n = 1
                epsilon = 1000
                for count in range(400): # while epsilon > e_threshold:
                    selected = random.choices(strings, k=1)[0]
                    # disturb_prompt = defence_gramfomer(selected)
                    disturb_prompt = defence_spellchecker(selected)
                    # disturb_prompt = defence_autocorrect(selected)
                    if disturb_prompt == ori_prompt:
                        # Defence fully restored the original prompt: count as robust
                        # without spending any image generations.
                        Non_AE += 1
                        logger.info(f"dis_prompt: {selected}")
                        logger.info(f"ori_prompt: {ori_prompt}")
                        logger.info(f"same")
                    else:
                        logger.info(f"unsame")
                        logger.info(f"selected: {selected}")
                        logger.info(f"revised: {disturb_prompt}")
                        L_distance.append(Levenshtein.distance(ori_prompt, disturb_prompt))
                        # whether_robust, dis_loss, stop_early = cal_loss(ori_loss, disturb_prompt, ori_prompt)
                        whether_robust = cal_loss(ori_loss, disturb_prompt, ori_prompt)
                        Non_AE += 1 if whether_robust else 0

                        # efficient_n += 1 if stop_early else 0
                        # AdvSt2i.append(sum(dis_loss) / len(dis_loss))

                    robust_left, robust_right, epsilon = calculate_R(Non_AE, n)
                    robust_re.append((robust_left, robust_right))
                    epsilon_re.append(epsilon)
                    logger.info(f"stop_early: {efficient_n}")
                    logger.info(f"Non_AE: {Non_AE}")
                    logger.info(f"n: {n}")
                    logger.info(f"robust reach: {robust_left} , {robust_right}")
                    logger.info(f"epsilon reach: {epsilon}")
                    print("*" * 120)
                    logger.info(f"*" * 120)
                    n += 1
                print("*" * 120)
                logger.info(f"*" * 120)
                logger.info(f"robust = {robust_re}")
                logger.info(f"epsilon = {epsilon_re}")
                logger.info(f"stop_early = {efficient_n}")
                logger.info(f"E_n = {Non_AE}")
                logger.info(f"n = {n}")
                # NOTE(review): AdvSt2i is never appended to (the append is
                # commented out above), so np.mean(AdvSt2i) is mean of an
                # empty list (nan with a RuntimeWarning) — confirm intent.
                logger.info(f"AdvSt2i = {round(np.mean(AdvSt2i), 2)}")
                logger.info(f"OriSt2i = {round(np.mean(ori_loss), 2)}")
                logger.info(f"Levenshtein = {round(np.mean(L_distance), 2)}")
                logger.info(f"robust = {robust_left} , {robust_right}")
                logger.info(f"epsilon = {epsilon}")

    # NOTE(review): indentation of this timing footer was reconstructed from a
    # flattened paste — it is assumed to sit at the top level of the __main__
    # guard, after all prompts are processed.  `logger` is only defined when
    # the index == 19 branch ran.
    end_time = time.time()
    elapsed_time = end_time - start_time
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
    logger.info(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
    logger.info(f"&" * 150)

0 commit comments

Comments
 (0)