@@ -155,81 +155,80 @@ def defence_autocorrect(influent_sentence):
155
155
start_time = time .time ()
156
156
origin_prompts = get_origin_prompt (origin_prompt_path )
157
157
for index , ori_prompt in origin_prompts .items ():
158
- if index == 19 :
159
- AEdata_path = f"./generate_AE/coco/char_AE/result_{ index } .csv"
160
- logger = setup_logger (f"adaptive_log/coco_update/spellchecker/log_char_{ index } .log" )
161
- logger .info (f"sigma: { sigma } " )
162
- logger .info (f"num_inference_steps: { num_inference_steps } " )
163
- logger .info (f"num_batch: { num_batch } " )
164
- logger .info (f"batch_size: { batch_size } " )
165
- logger .info (AEdata_path )
166
- logger .info (f"ori_prompt: { ori_prompt } " )
167
- df = pd .read_csv (AEdata_path )
168
- ori_loss = []
169
- for i in range (num_batch ):
170
- generator = torch .Generator (device ).manual_seed (1023 + i )
171
- images = pipe ([ori_prompt ] * batch_size , num_inference_steps = num_inference_steps , generator = generator )
172
- for j in range (batch_size ):
173
- ori_loss .append (calculate_text_image_distance (ori_prompt , images .images [j ]))
174
- logger .info (f"ori_loss: { len (ori_loss )} { ori_loss } " )
175
- logger .info (f"*" * 120 )
176
- for id in range (1 , 2 ):
177
- efficient_n = 0
178
- Non_AE , n = 0 , 0
179
- L_distance , AdvSt2i = [], []
180
- robust_re , epsilon_re = [], []
181
- sample_data = list (df [f"Column { id } " ].dropna ())
182
- strings = [line .split (':' )[0 ].strip () for line in sample_data [1 :]]
183
- logger .info (f"disturb rate: { id } " )
184
- logger .info (f"disturb_num: { sample_data [0 ]} " )
185
- n = 1
186
- epsilon = 1000
187
- for count in range (sample_num ):
188
- selected = random .choices (strings , k = 1 )[0 ]
189
- disturb_prompt = defence_spellchecker (selected )
190
- if disturb_prompt == ori_prompt :
191
- Non_AE += 1
192
- logger .info (f"dis_prompt: { selected } " )
193
- logger .info (f"ori_prompt: { ori_prompt } " )
194
- logger .info (f"same" )
195
- else :
196
- logger .info (f"unsame" )
197
- logger .info (f"selected: { selected } " )
198
- logger .info (f"revised: { disturb_prompt } " )
199
- L_distance .append (Levenshtein .distance (ori_prompt , disturb_prompt ))
200
- whether_robust = cal_loss (ori_loss , disturb_prompt , ori_prompt )
201
- Non_AE += 1 if whether_robust else 0
202
-
203
- robust_left , robust_right , epsilon = calculate_R (Non_AE , n )
204
- robust_re .append ((robust_left , robust_right ))
205
- epsilon_re .append (epsilon )
206
- logger .info (f"stop_early: { efficient_n } " )
207
- logger .info (f"Non_AE: { Non_AE } " )
208
- logger .info (f"n: { n } " )
209
- logger .info (f"robust reach: { robust_left } , { robust_right } " )
210
- logger .info (f"epsilon reach: { epsilon } " )
211
- print ("*" * 120 )
212
- logger .info (f"*" * 120 )
213
- n += 1
158
+ AEdata_path = f"./generate_AE/coco/char_AE/result_{ index } .csv"
159
+ logger = setup_logger (f"adaptive_log/coco_update/spellchecker/log_char_{ index } .log" )
160
+ logger .info (f"sigma: { sigma } " )
161
+ logger .info (f"num_inference_steps: { num_inference_steps } " )
162
+ logger .info (f"num_batch: { num_batch } " )
163
+ logger .info (f"batch_size: { batch_size } " )
164
+ logger .info (AEdata_path )
165
+ logger .info (f"ori_prompt: { ori_prompt } " )
166
+ df = pd .read_csv (AEdata_path )
167
+ ori_loss = []
168
+ for i in range (num_batch ):
169
+ generator = torch .Generator (device ).manual_seed (1023 + i )
170
+ images = pipe ([ori_prompt ] * batch_size , num_inference_steps = num_inference_steps , generator = generator )
171
+ for j in range (batch_size ):
172
+ ori_loss .append (calculate_text_image_distance (ori_prompt , images .images [j ]))
173
+ logger .info (f"ori_loss: { len (ori_loss )} { ori_loss } " )
174
+ logger .info (f"*" * 120 )
175
+ for id in range (1 , 2 ):
176
+ efficient_n = 0
177
+ Non_AE , n = 0 , 0
178
+ L_distance , AdvSt2i = [], []
179
+ robust_re , epsilon_re = [], []
180
+ sample_data = list (df [f"Column { id } " ].dropna ())
181
+ strings = [line .split (':' )[0 ].strip () for line in sample_data [1 :]]
182
+ logger .info (f"disturb rate: { id } " )
183
+ logger .info (f"disturb_num: { sample_data [0 ]} " )
184
+ n = 1
185
+ epsilon = 1000
186
+ for count in range (sample_num ):
187
+ selected = random .choices (strings , k = 1 )[0 ]
188
+ disturb_prompt = defence_spellchecker (selected )
189
+ if disturb_prompt == ori_prompt :
190
+ Non_AE += 1
191
+ logger .info (f"dis_prompt: { selected } " )
192
+ logger .info (f"ori_prompt: { ori_prompt } " )
193
+ logger .info (f"same" )
194
+ else :
195
+ logger .info (f"unsame" )
196
+ logger .info (f"selected: { selected } " )
197
+ logger .info (f"revised: { disturb_prompt } " )
198
+ L_distance .append (Levenshtein .distance (ori_prompt , disturb_prompt ))
199
+ whether_robust = cal_loss (ori_loss , disturb_prompt , ori_prompt )
200
+ Non_AE += 1 if whether_robust else 0
201
+
202
+ robust_left , robust_right , epsilon = calculate_R (Non_AE , n )
203
+ robust_re .append ((robust_left , robust_right ))
204
+ epsilon_re .append (epsilon )
205
+ logger .info (f"stop_early: { efficient_n } " )
206
+ logger .info (f"Non_AE: { Non_AE } " )
207
+ logger .info (f"n: { n } " )
208
+ logger .info (f"robust reach: { robust_left } , { robust_right } " )
209
+ logger .info (f"epsilon reach: { epsilon } " )
214
210
print ("*" * 120 )
215
211
logger .info (f"*" * 120 )
216
- logger .info (f"robust = { robust_re } " )
217
- logger .info (f"epsilon = { epsilon_re } " )
218
- logger .info (f"stop_early = { efficient_n } " )
219
- logger .info (f"E_n = { Non_AE } " )
220
- logger .info (f"n = { n } " )
221
- logger .info (f"AdvSt2i = { round (np .mean (AdvSt2i ), 2 )} " )
222
- logger .info (f"OriSt2i = { round (np .mean (ori_loss ), 2 )} " )
223
- logger .info (f"Levenshtein = { round (np .mean (L_distance ), 2 )} " )
224
- logger .info (f"robust = { robust_left } , { robust_right } " )
225
- logger .info (f"epsilon = { epsilon } " )
226
-
227
- end_time = time .time ()
228
- elapsed_time = end_time - start_time
229
- hours , remainder = divmod (elapsed_time , 3600 )
230
- minutes , seconds = divmod (remainder , 60 )
231
- print (f"time cost: { int (hours )} hours, { int (minutes )} minutes, { int (seconds )} seconds" )
232
- logger .info (f"time cost: { int (hours )} hours, { int (minutes )} minutes, { int (seconds )} seconds" )
233
- logger .info (f"&" * 150 )
212
+ n += 1
213
+ print ("*" * 120 )
214
+ logger .info (f"*" * 120 )
215
+ logger .info (f"robust = { robust_re } " )
216
+ logger .info (f"epsilon = { epsilon_re } " )
217
+ logger .info (f"stop_early = { efficient_n } " )
218
+ logger .info (f"E_n = { Non_AE } " )
219
+ logger .info (f"n = { n } " )
220
+ logger .info (f"AdvSt2i = { round (np .mean (AdvSt2i ), 2 )} " )
221
+ logger .info (f"OriSt2i = { round (np .mean (ori_loss ), 2 )} " )
222
+ logger .info (f"Levenshtein = { round (np .mean (L_distance ), 2 )} " )
223
+ logger .info (f"robust = { robust_left } , { robust_right } " )
224
+ logger .info (f"epsilon = { epsilon } " )
225
+
226
+ end_time = time .time ()
227
+ elapsed_time = end_time - start_time
228
+ hours , remainder = divmod (elapsed_time , 3600 )
229
+ minutes , seconds = divmod (remainder , 60 )
230
+ print (f"time cost: { int (hours )} hours, { int (minutes )} minutes, { int (seconds )} seconds" )
231
+ logger .info (f"time cost: { int (hours )} hours, { int (minutes )} minutes, { int (seconds )} seconds" )
232
+ logger .info (f"&" * 150 )
234
233
235
234
0 commit comments