
Commit 43302c3

Author: Yi Zhang
Commit message: update
1 parent 4479b49 commit 43302c3

File tree: 3 files changed (+79, -69 lines)

config.py (+1)

@@ -5,6 +5,7 @@
 batch_size = 12
 sample_num = 800
 model_id = "runwayml/stable-diffusion-v1-5"
+clip_version = "openai/clip-vit-large-patch14-336"
 
 e_threshold = 0.08
 sigma = 0.3
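Note: the new clip_version entry pins the CLIP backbone used for text-image scoring; version_sample.py (below) passes it to what appears to be torchmetrics' CLIPScore (the model_name_or_path constructor argument matches that API). A minimal, self-contained sketch of that usage, assuming torchmetrics; the dummy image and prompt are illustrative, not part of the commit:

import torch
from PIL import Image
from torchvision import transforms
from torchmetrics.multimodal import CLIPScore

clip_version = "openai/clip-vit-large-patch14-336"   # value added to config.py in this commit
metric = CLIPScore(model_name_or_path=clip_version)

image = Image.new("RGB", (512, 512), color="white")  # placeholder image for illustration only
img = transforms.ToTensor()(image) * 255             # CLIPScore expects pixel values in [0, 255]
score = metric(img, "a photo of a white square")     # higher score = closer text-image alignment
print(score.item())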

docs/perturb_example.PNG (-83.9 KB)

Binary file not shown.

version_sample.py (+78, -69)
@@ -1,5 +1,5 @@
 import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 import torch
 device = "cuda" if torch.cuda.is_available() else "cpu"
 import logging
@@ -16,12 +16,12 @@
 import pandas as pd
 torch.cuda.empty_cache()
 from config import (
-    sample_num, origin_prompt_path, sigma,
+    e_threshold, origin_prompt_path, sigma, clip_version,
     num_inference_steps, num_batch, batch_size,
     model_id
 )
 import random
-# random.seed(42)
+
 
 def setup_logger(file_name):
     logger = logging.getLogger(file_name)
@@ -35,12 +35,14 @@ def setup_logger(file_name):
     return logger
 
 pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+# pipe = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', torch_dtype=torch.float16)
 pipe = pipe.to(device)
 
 transform = transforms.Compose([transforms.ToTensor()])
-metric = CLIPScore().to(device)
+metric = CLIPScore(model_name_or_path=clip_version).to(device)
 
 def calculate_text_image_distance(text, image):
+    # print(transform(image).shape)  # torch.Size([1, 512, 512])
     img = transform(image)*255
     score = metric(img.to(device), text)
     return score.detach().cpu().numpy().item()
@@ -62,6 +64,17 @@ def calculate_t_zValue(data1, data2):
     return p_val, z_value
 
 
+def generate_func(pipe, prompt, seed):
+    while True:
+        generator = torch.Generator(device="cuda").manual_seed(seed)
+        output = pipe([prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
+        print(output.nsfw_content_detected, seed)
+        if any(output.nsfw_content_detected):
+            print(f"Potential NSFW content was detected with seed {seed}. Trying a new seed.")
+            seed = seed + 100
+        else:
+            return output
+
 def cal_loss(ori_loss, disturb_prompt, ori_prompt):
     print("dis_prompt", disturb_prompt)
     print("ori_prompt", ori_prompt)
@@ -71,18 +84,20 @@ def cal_loss(ori_loss, disturb_prompt, ori_prompt):
     dis_interim_loss = []
     alpha = [0.0148, 0.0262, 0.0354, 0.0432, 0.05]
     futility_boundary = [-0.145, 0.511, 1.027, 1.497, float('inf')]
-    efficacy_boundary = [2.176, 2.144, 2.113, 2.090, 2.071]
-    for i in range(num_batch):
+    efficacy_boundary = [2.176, 2.144, 2.113, 2.090, 2.071]  # critical values: reject H0 when the Z-score exceeds them
+    for i in range(num_batch):  # 5
         ori_interim_loss = ori_loss[0: batch_size*(i+1)]
-        generator = torch.Generator(device).manual_seed(1023+i)
-        images = pipe([disturb_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
-        for j in range(batch_size):
+        # generator = torch.Generator(device).manual_seed(1023+i)
+        # images = pipe([disturb_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
+        images = generate_func(pipe, disturb_prompt, seed=1023+i)
+        for j in range(batch_size):  # 5
             dis_interim_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
+        # _, p_val = mannwhitneyu(dis_interim_loss, ori_interim_loss)
         logger.info(f"dis_interim_loss: {len(dis_interim_loss)}; {dis_interim_loss}")
         logger.info(f"ori_interim_loss: {len(ori_interim_loss)}; {ori_interim_loss}")
         _, p_1 = shapiro(ori_interim_loss[0: 12*(i+1)])
         _, p_2 = shapiro(dis_interim_loss[0: 12*(i+1)])
-        if p_1 > 0.05 and p_2 > 0.05:
+        if p_1 > 0.05 and p_2 > 0.05:  # both samples look normally distributed
             p_val, z_val = calculate_t_zValue(ori_interim_loss[0: 12*(i+1)], dis_interim_loss[0: 12*(i+1)])
         else:
             p_val, z_val = calculate_u_zValue(ori_interim_loss[0: 12*(i+1)], dis_interim_loss[0: 12*(i+1)])
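Context for the hunk above: cal_loss runs an interim analysis per batch, using a Shapiro-Wilk normality check to route between a t-test (calculate_t_zValue) and a Mann-Whitney U test (calculate_u_zValue), then compares the resulting Z-score against the futility/efficacy boundaries. A minimal sketch of that routing with scipy; the helper name and sample data are illustrative, not taken from the commit:

import numpy as np
from scipy.stats import shapiro, ttest_ind, mannwhitneyu

def interim_test(ori_scores, dis_scores, alpha=0.05):
    # Route to a parametric or non-parametric test based on normality of both samples,
    # mirroring the shapiro -> t-test / Mann-Whitney choice in cal_loss.
    _, p1 = shapiro(ori_scores)
    _, p2 = shapiro(dis_scores)
    if p1 > alpha and p2 > alpha:
        stat, p_val = ttest_ind(ori_scores, dis_scores)
    else:
        stat, p_val = mannwhitneyu(ori_scores, dis_scores)
    return stat, p_val

rng = np.random.default_rng(0)
ori = rng.normal(30.0, 1.0, size=12)   # e.g. CLIP scores for the original prompt
dis = rng.normal(28.5, 1.0, size=12)   # e.g. CLIP scores for a perturbed prompt
print(interim_test(ori, dis))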
@@ -107,6 +122,7 @@ def get_AE(sample_data):
     import random
     random.seed(42)
     strings = [line.split(':')[0].strip() for line in sample_data[1:]]
+    # sampled_strings = random.sample(strings, len(strings))
     sampled_strings = random.choices(strings, k=1)
     return sampled_strings
 
@@ -130,80 +146,73 @@ def get_origin_prompt(origin_prompt_path):
     return origin_prompt
 
 
+
 if __name__ == "__main__":
     start_time = time.time()
     robust_left, robust_right = 0, 0
-    origin_prompts = get_origin_prompt(origin_prompt_path)
-    for index, ori_prompt in origin_prompts.items():
+    test_path = "dataset/"
+    for index in range(1, 100):
+        AEdata_path = f"{test_path}/{index}.txt"
+        with open(AEdata_path, 'r', encoding = 'utf-8') as file:
+            content = file.readlines()
+        ori_prompt = content[0].strip()
         efficient_m, efficient_n = 0, 0
-        AEdata_path = f"./generate_AE/coco/char_AE/result_{index}.csv"
-        logger = setup_logger(f"adaptive_log/log_char_{index}.log")
+        logger = setup_logger(f"adaptive_log/origin_result/10_rate/log_char_{index}.log")
         logger.info(f"sigma: {sigma}")
         logger.info(f"num_inference_steps: {num_inference_steps}")
         logger.info(f"num_batch: {num_batch}")
         logger.info(f"batch_size: {batch_size}")
         logger.info(AEdata_path)
         logger.info(f"ori_prompt: {ori_prompt}")
-        df = pd.read_csv(AEdata_path)
         ori_loss = []
         for i in range(num_batch):
-            generator = torch.Generator(device).manual_seed(1023+i)
-            images = pipe([ori_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
+            # generator = torch.Generator(device).manual_seed(1023+i)
+            # images = pipe([ori_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
+            images = generate_func(pipe, ori_prompt, seed=1023+i)
             for j in range(batch_size):
                 ori_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
         logger.info(f"ori_loss: {len(ori_loss)} {ori_loss}")
         logger.info(f"*" * 120)
-        for id in range(2, 3):
-            efficient_n = 0
-            Non_AE, n = 0, 0
-            L_distance, AdvSt2i = [], []
-            robust_re, epsilon_re = [], []
-            sample_data = list(df[f"Column {id}"].dropna())
-            strings = [line.split(':')[0].strip() for line in sample_data[1:]]
-            logger.info(f"disturb rate: {id}")
-            logger.info(f"disturb_num: {sample_data[0]}")
-            n = 1
-            epsilon = 1000
-            for count in range(sample_num):
-                disturb_prompt = random.choices(strings, k=1)[0]
-                L_distance.append(Levenshtein.distance(ori_prompt, disturb_prompt))
-                whether_robust = cal_loss(ori_loss, disturb_prompt, ori_prompt)
-                Non_AE += 1 if whether_robust else 0
-                robust_left, robust_right, epsilon = calculate_R(Non_AE, n)
-                robust_re.append((robust_left, robust_right))
-                epsilon_re.append(epsilon)
-                logger.info(f"stop_early: {efficient_n}")
-                logger.info(f"futility: {efficient_m}")
-                logger.info(f"Non_AE: {Non_AE}")
-                logger.info(f"n: {n}")
-                logger.info(f"robust reach: {robust_left} , {robust_right}")
-                logger.info(f"epsilon reach: {epsilon}")
-                print("*" * 120)
-                logger.info(f"*" * 120)
-                n += 1
+        efficient_n = 0
+        Non_AE, n = 0, 0
+        L_distance, AdvSt2i = [], []
+        robust_re, epsilon_re = [], []
+        n = 1
+        epsilon = 1000
+        for item in content[1:]:
+            disturb_prompt = item.strip()
+            whether_robust = cal_loss(ori_loss, disturb_prompt, ori_prompt)
+            Non_AE += 1 if whether_robust else 0
+            robust_left, robust_right, epsilon = calculate_R(Non_AE, n)
+            robust_re.append((robust_left, robust_right))
+            epsilon_re.append(epsilon)
+            logger.info(f"stop_early: {efficient_n}")
+            logger.info(f"futility: {efficient_m}")
+            logger.info(f"Non_AE: {Non_AE}")
+            logger.info(f"n: {n}")
+            logger.info(f"robust reach: {robust_left} , {robust_right}")
+            logger.info(f"epsilon reach: {epsilon}")
             print("*" * 120)
             logger.info(f"*" * 120)
-            logger.info(f"robust = {robust_re}")
-            logger.info(f"epsilon = {epsilon_re}")
-            logger.info(f"stop_early = {efficient_n}")
-            logger.info(f"futility = {efficient_m}")
-            logger.info(f"Non_AE = {Non_AE}")
-            logger.info(f"n = {n}")
-            logger.info(f"AdvSt2i = {round(np.mean(AdvSt2i), 2)}")
-            logger.info(f"OriSt2i = {round(np.mean(ori_loss), 2)}")
-            logger.info(f"Levenshtein = {round(np.mean(L_distance), 2)}")
-            logger.info(f"robust = {robust_left} , {robust_right}")
-            logger.info(f"epsilon = {epsilon}")
-
-    end_time = time.time()
-    elapsed_time = end_time - start_time
-    hours, remainder = divmod(elapsed_time, 3600)
-    minutes, seconds = divmod(remainder, 60)
-    print(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
-    logger.info(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
-    logger.info(f"&" * 150)
-
-
-
-
-
+            n += 1
+        print("*" * 120)
+        logger.info(f"*" * 120)
+        logger.info(f"robust = {robust_re}")
+        logger.info(f"epsilon = {epsilon_re}")
+        logger.info(f"stop_early = {efficient_n}")
+        logger.info(f"futility = {efficient_m}")
+        logger.info(f"Non_AE = {Non_AE}")
+        logger.info(f"n = {n}")
+        logger.info(f"AdvSt2i = {round(np.mean(AdvSt2i), 2)}")
+        logger.info(f"OriSt2i = {round(np.mean(ori_loss), 2)}")
+        logger.info(f"Levenshtein = {round(np.mean(L_distance), 2)}")
+        logger.info(f"robust = {robust_left} , {robust_right}")
+        logger.info(f"epsilon = {epsilon}")
+
+        end_time = time.time()
+        elapsed_time = end_time - start_time
+        hours, remainder = divmod(elapsed_time, 3600)
+        minutes, seconds = divmod(remainder, 60)
+        print(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
+        logger.info(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
+        logger.info(f"&" * 150)
