1
1
import os

# Pin the process to GPU 0; must be set before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

# Fall back to CPU when no CUDA device is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"

import logging
import pandas as pd

# Release any cached GPU memory left over from a previous run in this process.
torch.cuda.empty_cache()

# Experiment hyper-parameters live in the project-local config module.
from config import (
    e_threshold, origin_prompt_path, sigma, clip_version,
    num_inference_steps, num_batch, batch_size,
    model_id
)
import random
26
def setup_logger (file_name ):
27
27
logger = logging .getLogger (file_name )
@@ -35,12 +35,14 @@ def setup_logger(file_name):
35
35
return logger
36
36
37
37
# Load the Stable Diffusion pipeline selected in config (fp16 weights) and
# move it onto the chosen device.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# pipe = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', torch_dtype=torch.float16)
pipe = pipe.to(device)

# PIL image -> float tensor in [0, 1]; CLIPScore model version comes from config.
transform = transforms.Compose([transforms.ToTensor()])
metric = CLIPScore(model_name_or_path=clip_version).to(device)
42
43
43
44
def calculate_text_image_distance(text, image):
    """Return the CLIP score between *text* and a PIL *image* as a float.

    The image is converted to a tensor and rescaled to [0, 255], the pixel
    range CLIPScore expects, before being moved to the compute device.
    """
    pixels = transform(image).mul(255)
    similarity = metric(pixels.to(device), text)
    return similarity.detach().cpu().numpy().item()
@@ -62,6 +64,17 @@ def calculate_t_zValue(data1, data2):
62
64
return p_val , z_value
63
65
64
66
67
def generate_func(pipe, prompt, seed, max_attempts=50):
    """Generate one batch of images for *prompt*, retrying with fresh seeds
    whenever the pipeline's safety checker flags any image as NSFW.

    Args:
        pipe: diffusers StableDiffusionPipeline, already moved to `device`.
        prompt: text prompt, duplicated `batch_size` times per call.
        seed: initial RNG seed; bumped by 100 on each NSFW retry so the
            next attempt samples a different latent.
        max_attempts: retry cap so a prompt that always trips the safety
            checker cannot spin forever (the original looped unboundedly).

    Returns:
        The first pipeline output whose batch contains no NSFW-flagged image.

    Raises:
        RuntimeError: if every attempt up to `max_attempts` was flagged.
    """
    for _ in range(max_attempts):
        # Use the module-level `device` rather than hard-coding "cuda" so the
        # script also runs on CPU-only hosts.
        generator = torch.Generator(device=device).manual_seed(seed)
        output = pipe([prompt] * batch_size, num_inference_steps=num_inference_steps, generator=generator)
        print(output.nsfw_content_detected, seed)
        # nsfw_content_detected is None when the safety checker is disabled;
        # treat that the same as "nothing detected" instead of crashing in any().
        if output.nsfw_content_detected and any(output.nsfw_content_detected):
            print(f"Potential NSFW content was detected with seed {seed}. Trying a new seed.")
            seed = seed + 100
        else:
            return output
    raise RuntimeError(f"Safety checker flagged every attempt for prompt: {prompt!r}")
77
+
65
78
def cal_loss (ori_loss , disturb_prompt , ori_prompt ):
66
79
print ("dis_prompt" , disturb_prompt )
67
80
print ("ori_prompt" , ori_prompt )
@@ -71,18 +84,20 @@ def cal_loss(ori_loss, disturb_prompt, ori_prompt):
71
84
dis_interim_loss = []
72
85
alpha = [0.0148 , 0.0262 , 0.0354 , 0.0432 , 0.05 ]
73
86
futility_boundary = [- 0.145 , 0.511 , 1.027 , 1.497 , float ('inf' )]
74
- efficacy_boundary = [2.176 , 2.144 , 2.113 , 2.090 , 2.071 ]
75
- for i in range (num_batch ):
87
+ efficacy_boundary = [2.176 , 2.144 , 2.113 , 2.090 , 2.071 ] # critical value when Z-score > critical value, the reject the H0
88
+ for i in range (num_batch ): # 5
76
89
ori_interim_loss = ori_loss [0 : batch_size * (i + 1 )]
77
- generator = torch .Generator (device ).manual_seed (1023 + i )
78
- images = pipe ([disturb_prompt ] * batch_size , num_inference_steps = num_inference_steps , generator = generator )
79
- for j in range (batch_size ):
90
+ # generator = torch.Generator(device).manual_seed(1023+i)
91
+ # images = pipe([disturb_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
92
+ images = generate_func (pipe , disturb_prompt , seed = 1023 + i )
93
+ for j in range (batch_size ): # 5
80
94
dis_interim_loss .append (calculate_text_image_distance (ori_prompt , images .images [j ]))
95
+ # _, p_val = mannwhitneyu(dis_interim_loss, ori_interim_loss)
81
96
logger .info (f"dis_interim_loss: { len (dis_interim_loss )} ; { dis_interim_loss } " )
82
97
logger .info (f"ori_interim_loss: { len (ori_interim_loss )} ; { ori_interim_loss } " )
83
98
_ , p_1 = shapiro (ori_interim_loss [0 : 12 * (i + 1 )])
84
99
_ , p_2 = shapiro (dis_interim_loss [0 : 12 * (i + 1 )])
85
- if p_1 > 0.05 and p_2 > 0.05 :
100
+ if p_1 > 0.05 and p_2 > 0.05 : # normal distr
86
101
p_val , z_val = calculate_t_zValue (ori_interim_loss [0 : 12 * (i + 1 )], dis_interim_loss [0 : 12 * (i + 1 )])
87
102
else :
88
103
p_val , z_val = calculate_u_zValue (ori_interim_loss [0 : 12 * (i + 1 )], dis_interim_loss [0 : 12 * (i + 1 )])
@@ -107,6 +122,7 @@ def get_AE(sample_data):
107
122
import random
108
123
random .seed (42 )
109
124
strings = [line .split (':' )[0 ].strip () for line in sample_data [1 :]]
125
+ # sampled_strings = random.sample(strings, len(strings))
110
126
sampled_strings = random .choices (strings , k = 1 )
111
127
return sampled_strings
112
128
@@ -130,80 +146,73 @@ def get_origin_prompt(origin_prompt_path):
130
146
return origin_prompt
131
147
132
148
149
+
133
150
if __name__ == "__main__":
    start_time = time.time()
    robust_left, robust_right = 0, 0
    # Each dataset file holds the original prompt on its first line followed
    # by one disturbed (adversarial-candidate) prompt per subsequent line.
    test_path = "dataset/"
    for index in range(1, 100):
        AEdata_path = f"{test_path}/{index}.txt"
        with open(AEdata_path, 'r', encoding='utf-8') as file:
            content = file.readlines()
        ori_prompt = content[0].strip()
        efficient_m, efficient_n = 0, 0
        logger = setup_logger(f"adaptive_log/origin_result/10_rate/log_char_{index}.log")
        logger.info(f"sigma: {sigma}")
        logger.info(f"num_inference_steps: {num_inference_steps}")
        logger.info(f"num_batch: {num_batch}")
        logger.info(f"batch_size: {batch_size}")
        logger.info(AEdata_path)
        logger.info(f"ori_prompt: {ori_prompt}")
        # Baseline: CLIP distances between the original prompt and the images
        # generated from it (num_batch batches of batch_size images each),
        # generated with deterministic per-batch seeds 1023, 1024, ...
        ori_loss = []
        for i in range(num_batch):
            # generator = torch.Generator(device).manual_seed(1023+i)
            # images = pipe([ori_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
            images = generate_func(pipe, ori_prompt, seed=1023 + i)
            for j in range(batch_size):
                ori_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
        logger.info(f"ori_loss: {len(ori_loss)} {ori_loss}")
        logger.info(f"*" * 120)
        efficient_n = 0
        Non_AE, n = 0, 0
        # NOTE(review): L_distance and AdvSt2i are never appended to below, so
        # the np.mean(...) summaries at the end evaluate on empty lists (nan).
        L_distance, AdvSt2i = [], []
        robust_re, epsilon_re = [], []
        n = 1
        epsilon = 1000
        # Sequentially test each disturbed prompt; cal_loss runs the adaptive
        # hypothesis test, calculate_R updates the robustness interval.
        for item in content[1:]:
            disturb_prompt = item.strip()
            whether_robust = cal_loss(ori_loss, disturb_prompt, ori_prompt)
            Non_AE += 1 if whether_robust else 0
            robust_left, robust_right, epsilon = calculate_R(Non_AE, n)
            robust_re.append((robust_left, robust_right))
            epsilon_re.append(epsilon)
            logger.info(f"stop_early: {efficient_n}")
            logger.info(f"futility: {efficient_m}")
            logger.info(f"Non_AE: {Non_AE}")
            logger.info(f"n: {n}")
            logger.info(f"robust reach: {robust_left}, {robust_right}")
            logger.info(f"epsilon reach: {epsilon}")
            print("*" * 120)
            logger.info(f"*" * 120)
            n += 1
        # Per-file summary after all disturbed prompts have been tested.
        print("*" * 120)
        logger.info(f"*" * 120)
        logger.info(f"robust = {robust_re}")
        logger.info(f"epsilon = {epsilon_re}")
        logger.info(f"stop_early = {efficient_n}")
        logger.info(f"futility = {efficient_m}")
        logger.info(f"Non_AE = {Non_AE}")
        logger.info(f"n = {n}")
        logger.info(f"AdvSt2i = {round(np.mean(AdvSt2i), 2)}")
        logger.info(f"OriSt2i = {round(np.mean(ori_loss), 2)}")
        logger.info(f"Levenshtein = {round(np.mean(L_distance), 2)}")
        logger.info(f"robust = {robust_left}, {robust_right}")
        logger.info(f"epsilon = {epsilon}")

    # NOTE(review): indentation below is reconstructed from a diff rendering
    # that lost whitespace — the timing report is placed after the dataset
    # loop (one report per run); confirm against the original file.
    end_time = time.time()
    elapsed_time = end_time - start_time
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
    logger.info(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
    logger.info(f"&" * 150)
0 commit comments