@@ -146,73 +146,80 @@ def get_origin_prompt(origin_prompt_path):
146
146
return origin_prompt
147
147
148
148
149
-
150
149
if __name__ == "__main__" :
151
150
start_time = time .time ()
152
151
robust_left , robust_right = 0 , 0
153
- test_path = "dataset/"
154
- for index in range (1 , 100 ):
155
- AEdata_path = f"{ test_path } /{ index } .txt"
156
- with open (AEdata_path , 'r' , encoding = 'utf-8' ) as file :
157
- content = file .readlines ()
158
- ori_prompt = content [0 ].strip ()
152
+ origin_prompts = get_origin_prompt (origin_prompt_path )
153
+ for index , ori_prompt in origin_prompts .items ():
159
154
efficient_m , efficient_n = 0 , 0
155
+ # AEdata_path = f"./generate_AE/coco/char_AE/result_{index}.csv"
156
+ AEdata_path = f"./dataset4AT/test_in_train/result_{ index } .csv"
160
157
logger = setup_logger (f"adaptive_log/origin_result/10_rate/log_char_{ index } .log" )
161
158
logger .info (f"sigma: { sigma } " )
162
159
logger .info (f"num_inference_steps: { num_inference_steps } " )
163
160
logger .info (f"num_batch: { num_batch } " )
164
161
logger .info (f"batch_size: { batch_size } " )
165
162
logger .info (AEdata_path )
166
163
logger .info (f"ori_prompt: { ori_prompt } " )
164
+ df = pd .read_csv (AEdata_path )
167
165
ori_loss = []
168
166
for i in range (num_batch ):
169
167
# generator = torch.Generator(device).manual_seed(1023+i)
170
168
# images = pipe([ori_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
171
- images = generate_func (pipe , ori_prompt , seed = 1023 + i )
169
+ images = generate_func (pipe , ori_prompt , seed = 2023 + i )
172
170
for j in range (batch_size ):
173
171
ori_loss .append (calculate_text_image_distance (ori_prompt , images .images [j ]))
174
172
logger .info (f"ori_loss: { len (ori_loss )} { ori_loss } " )
175
173
logger .info (f"*" * 120 )
176
- efficient_n = 0
177
- Non_AE , n = 0 , 0
178
- L_distance , AdvSt2i = [], []
179
- robust_re , epsilon_re = [], []
180
- n = 1
181
- epsilon = 1000
182
- for item in content [1 :]:
183
- disturb_prompt = item .strip ()
184
- whether_robust = cal_loss (ori_loss , disturb_prompt , ori_prompt )
185
- Non_AE += 1 if whether_robust else 0
186
- robust_left , robust_right , epsilon = calculate_R (Non_AE , n )
187
- robust_re .append ((robust_left , robust_right ))
188
- epsilon_re .append (epsilon )
189
- logger .info (f"stop_early: { efficient_n } " )
190
- logger .info (f"futility: { efficient_m } " )
191
- logger .info (f"Non_AE: { Non_AE } " )
192
- logger .info (f"n: { n } " )
193
- logger .info (f"robust reach: { robust_left } , { robust_right } " )
194
- logger .info (f"epsilon reach: { epsilon } " )
174
+ for id in range (1 , 2 ): # focus on 10% perturb rate
175
+ efficient_n = 0
176
+ Non_AE , n = 0 , 0
177
+ L_distance , AdvSt2i = [], []
178
+ robust_re , epsilon_re = [], []
179
+ sample_data = list (df [f"Column { id } " ].dropna ())
180
+ strings = [line .split (':' )[0 ].strip () for line in sample_data [1 :]]
181
+ logger .info (f"disturb rate: { id } " )
182
+ logger .info (f"disturb_num: { sample_data [0 ]} " )
183
+ n = 1
184
+ epsilon = 1000
185
+ # while epsilon > e_threshold:
186
+ for count in range (400 ):
187
+ disturb_prompt = random .choices (strings , k = 1 )[0 ]
188
+ L_distance .append (Levenshtein .distance (ori_prompt , disturb_prompt ))
189
+ whether_robust = cal_loss (ori_loss , disturb_prompt , ori_prompt )
190
+ Non_AE += 1 if whether_robust else 0
191
+ # AdvSt2i.append(sum(dis_loss) / len(dis_loss))
192
+ robust_left , robust_right , epsilon = calculate_R (Non_AE , n )
193
+ robust_re .append ((robust_left , robust_right ))
194
+ epsilon_re .append (epsilon )
195
+ logger .info (f"stop_early: { efficient_n } " )
196
+ logger .info (f"futility: { efficient_m } " )
197
+ logger .info (f"Non_AE: { Non_AE } " )
198
+ logger .info (f"n: { n } " )
199
+ logger .info (f"robust reach: { robust_left } , { robust_right } " )
200
+ logger .info (f"epsilon reach: { epsilon } " )
201
+ print ("*" * 120 )
202
+ logger .info (f"*" * 120 )
203
+ n += 1
195
204
print ("*" * 120 )
196
205
logger .info (f"*" * 120 )
197
- n += 1
198
- print ("*" * 120 )
199
- logger .info (f"*" * 120 )
200
- logger .info (f"robust = { robust_re } " )
201
- logger .info (f"epsilon = { epsilon_re } " )
202
- logger .info (f"stop_early = { efficient_n } " )
203
- logger .info (f"futility = { efficient_m } " )
204
- logger .info (f"Non_AE = { Non_AE } " )
205
- logger .info (f"n = { n } " )
206
- logger .info (f"AdvSt2i = { round (np .mean (AdvSt2i ), 2 )} " )
207
- logger .info (f"OriSt2i = { round (np .mean (ori_loss ), 2 )} " )
208
- logger .info (f"Levenshtein = { round (np .mean (L_distance ), 2 )} " )
209
- logger .info (f"robust = { robust_left } , { robust_right } " )
210
- logger .info (f"epsilon = { epsilon } " )
211
-
212
- end_time = time .time ()
213
- elapsed_time = end_time - start_time
214
- hours , remainder = divmod (elapsed_time , 3600 )
215
- minutes , seconds = divmod (remainder , 60 )
216
- print (f"time cost: { int (hours )} hours, { int (minutes )} minutes, { int (seconds )} seconds" )
217
- logger .info (f"time cost: { int (hours )} hours, { int (minutes )} minutes, { int (seconds )} seconds" )
218
- logger .info (f"&" * 150 )
206
+ logger .info (f"robust = { robust_re } " )
207
+ logger .info (f"epsilon = { epsilon_re } " )
208
+ logger .info (f"stop_early = { efficient_n } " )
209
+ logger .info (f"futility = { efficient_m } " )
210
+ logger .info (f"Non_AE = { Non_AE } " )
211
+ logger .info (f"n = { n } " )
212
+ logger .info (f"AdvSt2i = { round (np .mean (AdvSt2i ), 2 )} " )
213
+ logger .info (f"OriSt2i = { round (np .mean (ori_loss ), 2 )} " )
214
+ logger .info (f"Levenshtein = { round (np .mean (L_distance ), 2 )} " )
215
+ logger .info (f"robust = { robust_left } , { robust_right } " )
216
+ logger .info (f"epsilon = { epsilon } " )
217
+
218
+ end_time = time .time ()
219
+ elapsed_time = end_time - start_time
220
+ hours , remainder = divmod (elapsed_time , 3600 )
221
+ minutes , seconds = divmod (remainder , 60 )
222
+ print (f"time cost: { int (hours )} hours, { int (minutes )} minutes, { int (seconds )} seconds" )
223
+ logger .info (f"time cost: { int (hours )} hours, { int (minutes )} minutes, { int (seconds )} seconds" )
224
+ logger .info (f"&" * 150 )
225
+
0 commit comments