Skip to content

Commit 3049479

Browse files
author
Yi Zhang
committed
update
1 parent 563a35e commit 3049479

File tree

1 file changed

+56
-49
lines changed

1 file changed

+56
-49
lines changed

version_sample.py

+56-49
Original file line numberDiff line numberDiff line change
@@ -146,73 +146,80 @@ def get_origin_prompt(origin_prompt_path):
146146
return origin_prompt
147147

148148

149-
150149
if __name__ == "__main__":
    # Robustness-sampling driver: for every original prompt, build a baseline of
    # text-image distances on clean generations, then repeatedly sample perturbed
    # prompts from a CSV of adversarial candidates and update the robustness
    # interval (robust_left, robust_right) and its epsilon via calculate_R.
    # NOTE(review): reconstructed from a diff rendering; indentation inferred
    # from control-flow structure — confirm against the original repository.
    start_time = time.time()
    robust_left, robust_right = 0, 0

    # Load all original prompts up front (replaces the old per-index
    # "dataset/{index}.txt" reads). Assumed to map index -> prompt string.
    origin_prompts = get_origin_prompt(origin_prompt_path)
    for index, ori_prompt in origin_prompts.items():
        # efficient_m / efficient_n: early-stop / futility counters (logged only here).
        efficient_m, efficient_n = 0, 0
        # AEdata_path = f"./generate_AE/coco/char_AE/result_{index}.csv"
        AEdata_path = f"./dataset4AT/test_in_train/result_{index}.csv"
        logger = setup_logger(f"adaptive_log/origin_result/10_rate/log_char_{index}.log")
        logger.info(f"sigma: {sigma}")
        logger.info(f"num_inference_steps: {num_inference_steps}")
        logger.info(f"num_batch: {num_batch}")
        logger.info(f"batch_size: {batch_size}")
        logger.info(AEdata_path)
        logger.info(f"ori_prompt: {ori_prompt}")
        # Perturbed-prompt candidates, one column per disturb rate
        # (column naming "Column {id}" — see sampling loop below).
        df = pd.read_csv(AEdata_path)

        # Baseline: distances between the clean prompt and images generated from it.
        ori_loss = []
        for i in range(num_batch):
            # generator = torch.Generator(device).manual_seed(1023+i)
            # images = pipe([ori_prompt] * batch_size, num_inference_steps = num_inference_steps, generator = generator)
            images = generate_func(pipe, ori_prompt, seed=2023+i)
            for j in range(batch_size):
                ori_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
        logger.info(f"ori_loss: {len(ori_loss)} {ori_loss}")
        logger.info(f"*" * 120)

        for id in range(1, 2):  # focus on 10% perturb rate
            efficient_n = 0
            Non_AE, n = 0, 0
            L_distance, AdvSt2i = [], []
            robust_re, epsilon_re = [], []
            # First cell holds the disturb count; the rest are "prompt: score" rows.
            sample_data = list(df[f"Column {id}"].dropna())
            strings = [line.split(':')[0].strip() for line in sample_data[1:]]
            logger.info(f"disturb rate: {id}")
            logger.info(f"disturb_num: {sample_data[0]}")
            n = 1
            epsilon = 1000
            # while epsilon > e_threshold:
            # Fixed sampling budget of 400 draws (with replacement) instead of
            # iterating until the epsilon threshold is reached.
            for count in range(400):
                disturb_prompt = random.choices(strings, k=1)[0]
                L_distance.append(Levenshtein.distance(ori_prompt, disturb_prompt))
                # whether_robust: True when the perturbed prompt is NOT adversarial
                # relative to the clean baseline losses.
                whether_robust = cal_loss(ori_loss, disturb_prompt, ori_prompt)
                Non_AE += 1 if whether_robust else 0
                # AdvSt2i.append(sum(dis_loss) / len(dis_loss))
                robust_left, robust_right, epsilon = calculate_R(Non_AE, n)
                robust_re.append((robust_left, robust_right))
                epsilon_re.append(epsilon)
                logger.info(f"stop_early: {efficient_n}")
                logger.info(f"futility: {efficient_m}")
                logger.info(f"Non_AE: {Non_AE}")
                logger.info(f"n: {n}")
                logger.info(f"robust reach: {robust_left} , {robust_right}")
                logger.info(f"epsilon reach: {epsilon}")
                print("*" * 120)
                logger.info(f"*" * 120)
                n += 1
            print("*" * 120)
            logger.info(f"*" * 120)
            # Per-rate summary. NOTE(review): AdvSt2i is never appended to (the
            # append above is commented out), so np.mean(AdvSt2i) yields nan with
            # a RuntimeWarning — kept as-is to preserve the committed behavior.
            logger.info(f"robust = {robust_re}")
            logger.info(f"epsilon = {epsilon_re}")
            logger.info(f"stop_early = {efficient_n}")
            logger.info(f"futility = {efficient_m}")
            logger.info(f"Non_AE = {Non_AE}")
            logger.info(f"n = {n}")
            logger.info(f"AdvSt2i = {round(np.mean(AdvSt2i), 2)}")
            logger.info(f"OriSt2i = {round(np.mean(ori_loss), 2)}")
            logger.info(f"Levenshtein = {round(np.mean(L_distance), 2)}")
            logger.info(f"robust = {robust_left} , {robust_right}")
            logger.info(f"epsilon = {epsilon}")

    end_time = time.time()
    elapsed_time = end_time - start_time
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
    # Uses the logger from the last processed prompt.
    logger.info(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
    logger.info(f"&" * 150)

0 commit comments

Comments (0)