Skip to content

Commit c893ac9

Browse files
author
Yi Zhang
committed
update
1 parent b5e4b6a commit c893ac9

File tree

5 files changed

+141
-152
lines changed

5 files changed

+141
-152
lines changed

defence/gramformer_gec.py

-9
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,3 @@
1-
'''
2-
The Gramformer project for GEC
3-
4-
https://github.com/PrithivirajDamodaran/Gramformer
5-
6-
pip install -U git+https://github.com/PrithivirajDamodaran/Gramformer.git
7-
python -m spacy download en_core_web_sm
8-
9-
'''
101
from gramformer import Gramformer
112
import torch
123
from tqdm import tqdm

generate_AE/attack_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from textattack.augmentation import Augmenter
2020
# transformation = CompositeTransformation([WordSwapRandomCharacterInsertion(),WordSwapRandomCharacterSubstitution(),WordSwapRandomCharacterDeletion(),WordSwapNeighboringCharacterSwap(),WordSwapQWERTY()])
2121
transformation = CompositeTransformation([WordSwapQWERTY()])
22-
constraints = [RepeatModification()]
22+
constraints = [RepeatModification(),StopwordModification()]
2323

2424
import pandas as pd
2525
import clip

generate_AE/char_level.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from textattack.constraints.pre_transformation import RepeatModification
1818
from textattack.constraints.pre_transformation import StopwordModification
1919
from textattack.augmentation import Augmenter
20-
transformation = CompositeTransformation([WordSwapRandomCharacterInsertion(),WordSwapRandomCharacterSubstitution(),WordSwapRandomCharacterDeletion(),WordSwapNeighboringCharacterSwap(),WordSwapQWERTY()])
20+
# transformation = CompositeTransformation([WordSwapRandomCharacterInsertion(),WordSwapRandomCharacterSubstitution(),WordSwapRandomCharacterDeletion(),WordSwapNeighboringCharacterSwap(),WordSwapQWERTY()])
2121
# transformation = CompositeTransformation([WordSwapRandomCharacterSubstitution()])
22-
constraints = [RepeatModification()]
22+
constraints = [RepeatModification(), StopwordModification()]
2323

2424
import pandas as pd
2525
import clip

version_defence.py

+73-74
Original file line numberDiff line numberDiff line change
@@ -155,81 +155,80 @@ def defence_autocorrect(influent_sentence):
155155
start_time = time.time()
156156
origin_prompts = get_origin_prompt(origin_prompt_path)
157157
for index, ori_prompt in origin_prompts.items():
158-
if index == 19:
159-
AEdata_path = f"./generate_AE/coco/char_AE/result_{index}.csv"
160-
logger = setup_logger(f"adaptive_log/coco_update/spellchecker/log_char_{index}.log")
161-
logger.info(f"sigma: {sigma}")
162-
logger.info(f"num_inference_steps: {num_inference_steps}")
163-
logger.info(f"num_batch: {num_batch}")
164-
logger.info(f"batch_size: {batch_size}")
165-
logger.info(AEdata_path)
166-
logger.info(f"ori_prompt: {ori_prompt}")
167-
df = pd.read_csv(AEdata_path)
168-
ori_loss = []
169-
for i in range(num_batch):
170-
generator = torch.Generator(device).manual_seed(1023+i)
171-
images = pipe([ori_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
172-
for j in range(batch_size):
173-
ori_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
174-
logger.info(f"ori_loss: {len(ori_loss)} {ori_loss}")
175-
logger.info(f"*" * 120)
176-
for id in range(1, 2):
177-
efficient_n = 0
178-
Non_AE, n = 0, 0
179-
L_distance, AdvSt2i = [], []
180-
robust_re, epsilon_re = [], []
181-
sample_data = list(df[f"Column {id}"].dropna())
182-
strings = [line.split(':')[0].strip() for line in sample_data[1:]]
183-
logger.info(f"disturb rate: {id}")
184-
logger.info(f"disturb_num: {sample_data[0]}")
185-
n = 1
186-
epsilon = 1000
187-
for count in range(sample_num):
188-
selected = random.choices(strings, k=1)[0]
189-
disturb_prompt = defence_spellchecker(selected)
190-
if disturb_prompt == ori_prompt:
191-
Non_AE += 1
192-
logger.info(f"dis_prompt: {selected}")
193-
logger.info(f"ori_prompt: {ori_prompt}")
194-
logger.info(f"same")
195-
else:
196-
logger.info(f"unsame")
197-
logger.info(f"selected: {selected}")
198-
logger.info(f"revised: {disturb_prompt}")
199-
L_distance.append(Levenshtein.distance(ori_prompt, disturb_prompt))
200-
whether_robust = cal_loss(ori_loss, disturb_prompt, ori_prompt)
201-
Non_AE += 1 if whether_robust else 0
202-
203-
robust_left, robust_right, epsilon = calculate_R(Non_AE, n)
204-
robust_re.append((robust_left, robust_right))
205-
epsilon_re.append(epsilon)
206-
logger.info(f"stop_early: {efficient_n}")
207-
logger.info(f"Non_AE: {Non_AE}")
208-
logger.info(f"n: {n}")
209-
logger.info(f"robust reach: {robust_left} , {robust_right}")
210-
logger.info(f"epsilon reach: {epsilon}")
211-
print("*" * 120)
212-
logger.info(f"*" * 120)
213-
n += 1
158+
AEdata_path = f"./generate_AE/coco/char_AE/result_{index}.csv"
159+
logger = setup_logger(f"adaptive_log/coco_update/spellchecker/log_char_{index}.log")
160+
logger.info(f"sigma: {sigma}")
161+
logger.info(f"num_inference_steps: {num_inference_steps}")
162+
logger.info(f"num_batch: {num_batch}")
163+
logger.info(f"batch_size: {batch_size}")
164+
logger.info(AEdata_path)
165+
logger.info(f"ori_prompt: {ori_prompt}")
166+
df = pd.read_csv(AEdata_path)
167+
ori_loss = []
168+
for i in range(num_batch):
169+
generator = torch.Generator(device).manual_seed(1023+i)
170+
images = pipe([ori_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
171+
for j in range(batch_size):
172+
ori_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
173+
logger.info(f"ori_loss: {len(ori_loss)} {ori_loss}")
174+
logger.info(f"*" * 120)
175+
for id in range(1, 2):
176+
efficient_n = 0
177+
Non_AE, n = 0, 0
178+
L_distance, AdvSt2i = [], []
179+
robust_re, epsilon_re = [], []
180+
sample_data = list(df[f"Column {id}"].dropna())
181+
strings = [line.split(':')[0].strip() for line in sample_data[1:]]
182+
logger.info(f"disturb rate: {id}")
183+
logger.info(f"disturb_num: {sample_data[0]}")
184+
n = 1
185+
epsilon = 1000
186+
for count in range(sample_num):
187+
selected = random.choices(strings, k=1)[0]
188+
disturb_prompt = defence_spellchecker(selected)
189+
if disturb_prompt == ori_prompt:
190+
Non_AE += 1
191+
logger.info(f"dis_prompt: {selected}")
192+
logger.info(f"ori_prompt: {ori_prompt}")
193+
logger.info(f"same")
194+
else:
195+
logger.info(f"unsame")
196+
logger.info(f"selected: {selected}")
197+
logger.info(f"revised: {disturb_prompt}")
198+
L_distance.append(Levenshtein.distance(ori_prompt, disturb_prompt))
199+
whether_robust = cal_loss(ori_loss, disturb_prompt, ori_prompt)
200+
Non_AE += 1 if whether_robust else 0
201+
202+
robust_left, robust_right, epsilon = calculate_R(Non_AE, n)
203+
robust_re.append((robust_left, robust_right))
204+
epsilon_re.append(epsilon)
205+
logger.info(f"stop_early: {efficient_n}")
206+
logger.info(f"Non_AE: {Non_AE}")
207+
logger.info(f"n: {n}")
208+
logger.info(f"robust reach: {robust_left} , {robust_right}")
209+
logger.info(f"epsilon reach: {epsilon}")
214210
print("*" * 120)
215211
logger.info(f"*" * 120)
216-
logger.info(f"robust = {robust_re}")
217-
logger.info(f"epsilon = {epsilon_re}")
218-
logger.info(f"stop_early = {efficient_n}")
219-
logger.info(f"E_n = {Non_AE}")
220-
logger.info(f"n = {n}")
221-
logger.info(f"AdvSt2i = {round(np.mean(AdvSt2i), 2)}")
222-
logger.info(f"OriSt2i = {round(np.mean(ori_loss), 2)}")
223-
logger.info(f"Levenshtein = {round(np.mean(L_distance), 2)}")
224-
logger.info(f"robust = {robust_left} , {robust_right}")
225-
logger.info(f"epsilon = {epsilon}")
226-
227-
end_time = time.time()
228-
elapsed_time = end_time - start_time
229-
hours, remainder = divmod(elapsed_time, 3600)
230-
minutes, seconds = divmod(remainder, 60)
231-
print(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
232-
logger.info(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
233-
logger.info(f"&" * 150)
212+
n += 1
213+
print("*" * 120)
214+
logger.info(f"*" * 120)
215+
logger.info(f"robust = {robust_re}")
216+
logger.info(f"epsilon = {epsilon_re}")
217+
logger.info(f"stop_early = {efficient_n}")
218+
logger.info(f"E_n = {Non_AE}")
219+
logger.info(f"n = {n}")
220+
logger.info(f"AdvSt2i = {round(np.mean(AdvSt2i), 2)}")
221+
logger.info(f"OriSt2i = {round(np.mean(ori_loss), 2)}")
222+
logger.info(f"Levenshtein = {round(np.mean(L_distance), 2)}")
223+
logger.info(f"robust = {robust_left} , {robust_right}")
224+
logger.info(f"epsilon = {epsilon}")
225+
226+
end_time = time.time()
227+
elapsed_time = end_time - start_time
228+
hours, remainder = divmod(elapsed_time, 3600)
229+
minutes, seconds = divmod(remainder, 60)
230+
print(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
231+
logger.info(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
232+
logger.info(f"&" * 150)
234233

235234

version_sample.py

+65-66
Original file line numberDiff line numberDiff line change
@@ -135,74 +135,73 @@ def get_origin_prompt(origin_prompt_path):
135135
robust_left, robust_right = 0, 0
136136
origin_prompts = get_origin_prompt(origin_prompt_path)
137137
for index, ori_prompt in origin_prompts.items():
138-
if index >= 7 and index < 31:
139-
efficient_m, efficient_n = 0, 0
140-
AEdata_path = f"./generate_AE/coco/char_AE/result_{index}.csv"
141-
logger = setup_logger(f"adaptive_log/log_char_{index}.log")
142-
logger.info(f"sigma: {sigma}")
143-
logger.info(f"num_inference_steps: {num_inference_steps}")
144-
logger.info(f"num_batch: {num_batch}")
145-
logger.info(f"batch_size: {batch_size}")
146-
logger.info(AEdata_path)
147-
logger.info(f"ori_prompt: {ori_prompt}")
148-
df = pd.read_csv(AEdata_path)
149-
ori_loss = []
150-
for i in range(num_batch):
151-
generator = torch.Generator(device).manual_seed(1023+i)
152-
images = pipe([ori_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
153-
for j in range(batch_size):
154-
ori_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
155-
logger.info(f"ori_loss: {len(ori_loss)} {ori_loss}")
156-
logger.info(f"*" * 120)
157-
for id in range(2, 3):
158-
efficient_n = 0
159-
Non_AE, n = 0, 0
160-
L_distance, AdvSt2i = [], []
161-
robust_re, epsilon_re = [], []
162-
sample_data = list(df[f"Column {id}"].dropna())
163-
strings = [line.split(':')[0].strip() for line in sample_data[1:]]
164-
logger.info(f"disturb rate: {id}")
165-
logger.info(f"disturb_num: {sample_data[0]}")
166-
n = 1
167-
epsilon = 1000
168-
for count in range(sample_num):
169-
disturb_prompt = random.choices(strings, k=1)[0]
170-
L_distance.append(Levenshtein.distance(ori_prompt, disturb_prompt))
171-
whether_robust = cal_loss(ori_loss, disturb_prompt, ori_prompt)
172-
Non_AE += 1 if whether_robust else 0
173-
robust_left, robust_right, epsilon = calculate_R(Non_AE, n)
174-
robust_re.append((robust_left, robust_right))
175-
epsilon_re.append(epsilon)
176-
logger.info(f"stop_early: {efficient_n}")
177-
logger.info(f"futility: {efficient_m}")
178-
logger.info(f"Non_AE: {Non_AE}")
179-
logger.info(f"n: {n}")
180-
logger.info(f"robust reach: {robust_left} , {robust_right}")
181-
logger.info(f"epsilon reach: {epsilon}")
182-
print("*" * 120)
183-
logger.info(f"*" * 120)
184-
n += 1
138+
efficient_m, efficient_n = 0, 0
139+
AEdata_path = f"./generate_AE/coco/char_AE/result_{index}.csv"
140+
logger = setup_logger(f"adaptive_log/log_char_{index}.log")
141+
logger.info(f"sigma: {sigma}")
142+
logger.info(f"num_inference_steps: {num_inference_steps}")
143+
logger.info(f"num_batch: {num_batch}")
144+
logger.info(f"batch_size: {batch_size}")
145+
logger.info(AEdata_path)
146+
logger.info(f"ori_prompt: {ori_prompt}")
147+
df = pd.read_csv(AEdata_path)
148+
ori_loss = []
149+
for i in range(num_batch):
150+
generator = torch.Generator(device).manual_seed(1023+i)
151+
images = pipe([ori_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
152+
for j in range(batch_size):
153+
ori_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
154+
logger.info(f"ori_loss: {len(ori_loss)} {ori_loss}")
155+
logger.info(f"*" * 120)
156+
for id in range(2, 3):
157+
efficient_n = 0
158+
Non_AE, n = 0, 0
159+
L_distance, AdvSt2i = [], []
160+
robust_re, epsilon_re = [], []
161+
sample_data = list(df[f"Column {id}"].dropna())
162+
strings = [line.split(':')[0].strip() for line in sample_data[1:]]
163+
logger.info(f"disturb rate: {id}")
164+
logger.info(f"disturb_num: {sample_data[0]}")
165+
n = 1
166+
epsilon = 1000
167+
for count in range(sample_num):
168+
disturb_prompt = random.choices(strings, k=1)[0]
169+
L_distance.append(Levenshtein.distance(ori_prompt, disturb_prompt))
170+
whether_robust = cal_loss(ori_loss, disturb_prompt, ori_prompt)
171+
Non_AE += 1 if whether_robust else 0
172+
robust_left, robust_right, epsilon = calculate_R(Non_AE, n)
173+
robust_re.append((robust_left, robust_right))
174+
epsilon_re.append(epsilon)
175+
logger.info(f"stop_early: {efficient_n}")
176+
logger.info(f"futility: {efficient_m}")
177+
logger.info(f"Non_AE: {Non_AE}")
178+
logger.info(f"n: {n}")
179+
logger.info(f"robust reach: {robust_left} , {robust_right}")
180+
logger.info(f"epsilon reach: {epsilon}")
185181
print("*" * 120)
186182
logger.info(f"*" * 120)
187-
logger.info(f"robust = {robust_re}")
188-
logger.info(f"epsilon = {epsilon_re}")
189-
logger.info(f"stop_early = {efficient_n}")
190-
logger.info(f"futility = {efficient_m}")
191-
logger.info(f"Non_AE = {Non_AE}")
192-
logger.info(f"n = {n}")
193-
logger.info(f"AdvSt2i = {round(np.mean(AdvSt2i), 2)}")
194-
logger.info(f"OriSt2i = {round(np.mean(ori_loss), 2)}")
195-
logger.info(f"Levenshtein = {round(np.mean(L_distance), 2)}")
196-
logger.info(f"robust = {robust_left} , {robust_right}")
197-
logger.info(f"epsilon = {epsilon}")
198-
199-
end_time = time.time()
200-
elapsed_time = end_time - start_time
201-
hours, remainder = divmod(elapsed_time, 3600)
202-
minutes, seconds = divmod(remainder, 60)
203-
print(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
204-
logger.info(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
205-
logger.info(f"&" * 150)
183+
n += 1
184+
print("*" * 120)
185+
logger.info(f"*" * 120)
186+
logger.info(f"robust = {robust_re}")
187+
logger.info(f"epsilon = {epsilon_re}")
188+
logger.info(f"stop_early = {efficient_n}")
189+
logger.info(f"futility = {efficient_m}")
190+
logger.info(f"Non_AE = {Non_AE}")
191+
logger.info(f"n = {n}")
192+
logger.info(f"AdvSt2i = {round(np.mean(AdvSt2i), 2)}")
193+
logger.info(f"OriSt2i = {round(np.mean(ori_loss), 2)}")
194+
logger.info(f"Levenshtein = {round(np.mean(L_distance), 2)}")
195+
logger.info(f"robust = {robust_left} , {robust_right}")
196+
logger.info(f"epsilon = {epsilon}")
197+
198+
end_time = time.time()
199+
elapsed_time = end_time - start_time
200+
hours, remainder = divmod(elapsed_time, 3600)
201+
minutes, seconds = divmod(remainder, 60)
202+
print(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
203+
logger.info(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
204+
logger.info(f"&" * 150)
206205

207206

208207

0 commit comments

Comments
 (0)