import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
import logging
from scipy.stats import mannwhitneyu, shapiro
from diffusers import StableDiffusionPipeline
from torchmetrics.multimodal import CLIPScore
import torchvision.transforms as transforms
import Levenshtein
import numpy as np
import time
from scipy import stats
import pandas as pd
torch.cuda.empty_cache()
from config import (
    e_threshold, origin_prompt_path, sigma,
    num_inference_steps, num_batch, batch_size,
    model_id
)
import random
# random.seed(42)

def setup_logger(file_name):
    logger = logging.getLogger(file_name)
    logger.setLevel(logging.INFO)

    handler = logging.FileHandler(file_name)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S')
    handler.setFormatter(formatter)

    logger.addHandler(handler)
    return logger
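# Note: logging.getLogger caches loggers by name and setup_logger attaches a
# fresh FileHandler on every call, so calling it twice with the same file_name
# in one process would duplicate every log line.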

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# pipe = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', torch_dtype=torch.float16)
pipe = pipe.to(device)

transform = transforms.Compose([transforms.ToTensor()])
metric = CLIPScore().to(device)
def calculate_text_image_distance(text, image):
    # print(transform(image).shape) # torch.Size([1, 512, 512])
    # ToTensor yields floats in [0, 1]; CLIPScore expects pixel values in [0, 255].
    # Despite the name, this is a CLIP similarity: higher means closer alignment.
    img = transform(image) * 255
    score = metric(img.to(device), text)
    return score.detach().cpu().numpy().item()

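# Minimal usage sketch (hypothetical image, not part of the original run):
#   from PIL import Image
#   img = Image.new("RGB", (512, 512))
#   s = calculate_text_image_distance("a photo of a dog", img)  # higher = closer
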
def calculate_u_zValue(data1, data2):
    # One-sided Mann-Whitney U test; z is the normal approximation of U.
    U_statistic, p_val = mannwhitneyu(data1, data2, alternative='greater')
    n1, n2 = len(data1), len(data2)
    mean_U = n1 * n2 / 2
    std_U = np.sqrt(n1 * n2 * (n1 + n2 + 1) / 12)
    z_value = (U_statistic - mean_U) / std_U
    return p_val, z_value
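
# Quick self-check sketch (made-up numbers): with clearly separated samples the
# one-sided test gives a small p and a large positive z, e.g.
#   p, z = calculate_u_zValue([5, 6, 7, 8], [1, 2, 3, 4])  # expect p < 0.05, z > 0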

def calculate_t_zValue(data1, data2):
    t_statistic, p_val = stats.ttest_ind(data1, data2)
    mean1, mean2 = np.mean(data1), np.mean(data2)
    std1, std2 = np.std(data1, ddof=1), np.std(data2, ddof=1)
    n1, n2 = len(data1), len(data2)
    # z here is the Welch statistic (unpooled variances), while the p-value
    # above comes from the pooled-variance t-test.
    z_value = (mean1 - mean2) / np.sqrt((std1**2 / n1) + (std2**2 / n2))
    return p_val, z_value


def cal_loss(ori_loss, disturb_prompt, ori_prompt):
    print("dis_prompt", disturb_prompt)
    print("ori_prompt", ori_prompt)
    logger.info(f"dis_prompt: {disturb_prompt}")
    logger.info(f"ori_prompt: {ori_prompt}")
    logger.info("--" * 20)
    dis_interim_loss = []
    # Group-sequential design with num_batch interim looks: cumulative alpha
    # spent per look, plus futility and efficacy z-boundaries.
    alpha = [0.0148, 0.0262, 0.0354, 0.0432, 0.05]
    futility_boundary = [-0.145, 0.511, 1.027, 1.497, float('inf')]
    efficacy_boundary = [2.176, 2.144, 2.113, 2.090, 2.071]  # reject H0 when the z-score exceeds the critical value
    for i in range(num_batch):
        # Look i compares the first batch_size*(i+1) baseline scores against
        # perturbed-prompt scores generated with the same per-batch seeds.
        ori_interim_loss = ori_loss[0: batch_size * (i + 1)]
        generator = torch.Generator(device).manual_seed(1023 + i)
        images = pipe([disturb_prompt] * batch_size, num_inference_steps=num_inference_steps, generator=generator)
        for j in range(batch_size):
            dis_interim_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))

        logger.info(f"dis_interim_loss: {len(dis_interim_loss)}; {dis_interim_loss}")
        logger.info(f"ori_interim_loss: {len(ori_interim_loss)}; {ori_interim_loss}")
        # Shapiro-Wilk normality check decides between the t-test and the U-test.
        _, p_1 = shapiro(ori_interim_loss)
        _, p_2 = shapiro(dis_interim_loss)
        if p_1 > 0.05 and p_2 > 0.05:  # both samples look normal
            p_val, z_val = calculate_t_zValue(ori_interim_loss, dis_interim_loss)
        else:
            p_val, z_val = calculate_u_zValue(ori_interim_loss, dis_interim_loss)
        logger.info(f"p_val, z_val: {p_val} {z_val}")

        if z_val >= efficacy_boundary[i]:  # significant score drop: not robust
            return 0

        if z_val <= futility_boundary[i]:  # effect too small to ever reject: robust
            return 1

        if i == num_batch - 1:  # final look: fall back to the alpha threshold
            if p_val > alpha[i]:
                return 1
            else:
                return 0
        logger.info("--" * 20)
    return 1
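
# Reading cal_loss (hypothetical example prompts): it returns 1 when the
# perturbed prompt is judged robust and 0 when the CLIP-score drop is
# significant at an interim look or at the final analysis, e.g.
#   robust = cal_loss(ori_loss, "a photo of a dgo", "a photo of a dog")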


def get_AE(sample_data):
    import random
    # NOTE: re-seeding on every call makes the draw deterministic, so repeated
    # calls return the same sample.
    random.seed(42)
    strings = [line.split(':')[0].strip() for line in sample_data[1:]]
    # sampled_strings = random.sample(strings, len(strings))
    sampled_strings = random.choices(strings, k=1)
    return sampled_strings

def calculate_R(E_n, n):
    import math
    epsilon = math.sqrt((0.6 * math.log(math.log(n, 1.1) + 1, 10) + (1.8 ** -1) * math.log(24 / sigma, 10)) / n)
    robust_left = E_n / n - epsilon
    robust_right = E_n / n + epsilon
    return robust_left, robust_right, epsilon
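
# calculate_R brackets the empirical robustness rate E_n / n with a band whose
# half-width epsilon shrinks roughly like sqrt(log(log n) / n); the exact width
# also depends on sigma from config. Sketch with made-up counts:
#   lo, hi, eps = calculate_R(E_n=90, n=100)  # band around 0.9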

def get_origin_prompt(origin_prompt_path):
    # One prompt per line; keys are 1-based line numbers.
    origin_prompt = {}
    with open(origin_prompt_path, 'r') as file:
        for i, line in enumerate(file, start=1):
            origin_prompt[i] = line.strip()
    return origin_prompt

def defence_gramfomer(influent_sentence):
    from gramformer import Gramformer
    gf = Gramformer(models=1, use_gpu=False)  # 1=corrector, 2=detector
    corrected_sentences = gf.correct(influent_sentence, max_candidates=1)
    for corrected_sentence in corrected_sentences:  # gf.correct returns a set of candidates
        return corrected_sentence
    return ""
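
# Example (hypothetical input/output; the actual result depends on the
# gramformer model weights): defence_gramfomer("he are looking at sky") would
# typically return something like "He is looking at the sky."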

def defence_spellchecker(influent_sentence):
    from spellchecker import SpellChecker
    spell = SpellChecker()
    words = influent_sentence.split()
    misspelled = spell.unknown(words)  # e.g. {'qducks.'}
    corrected = []
    for word in words:
        fixed = spell.correction(word) if word in misspelled else word
        if fixed is not None:  # correction() returns None when it has no candidate
            corrected.append(fixed)
    return " ".join(corrected) + '.'
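
# Example (hypothetical input; correction quality depends on the pyspellchecker
# dictionary): defence_spellchecker("two qducks on a pond") would typically
# return "two ducks on a pond.".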

def defence_autocorrect(influent_sentence):
    from autocorrect import Speller
    spell = Speller(lang='en')
    text = influent_sentence
    corrected_text = ' '.join([spell(word) for word in text.split()])
    return corrected_text


if __name__ == "__main__":
    start_time = time.time()
    origin_prompts = get_origin_prompt(origin_prompt_path)
    for index, ori_prompt in origin_prompts.items():
        if index == 19:  # only evaluate prompt 19 in this run
            AEdata_path = f"./generate_AE/coco/char_AE/result_{index}.csv"
            logger = setup_logger(f"adaptive_log/coco_update/10_rate/defence/spellchecker/log_char_{index}.log")
            logger.info(f"sigma: {sigma}")
            logger.info(f"num_inference_steps: {num_inference_steps}")
            logger.info(f"num_batch: {num_batch}")
            logger.info(f"batch_size: {batch_size}")
            logger.info(AEdata_path)
            logger.info(f"ori_prompt: {ori_prompt}")
            df = pd.read_csv(AEdata_path)
            # Baseline: CLIP scores of the original prompt, generated with the
            # same per-batch seeds that cal_loss reuses for perturbed prompts.
            ori_loss = []
            for i in range(num_batch):
                generator = torch.Generator(device).manual_seed(1023 + i)
                images = pipe([ori_prompt] * batch_size, num_inference_steps=num_inference_steps, generator=generator)
                for j in range(batch_size):
                    ori_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
            logger.info(f"ori_loss: {len(ori_loss)} {ori_loss}")
            logger.info("*" * 120)
            for rate_id in range(1, 2):  # focus on the 10% perturbation rate
                efficient_n = 0
                Non_AE, n = 0, 0
                L_distance, AdvSt2i = [], []
                robust_re, epsilon_re = [], []
                sample_data = list(df[f"Column {rate_id}"].dropna())
                strings = [line.split(':')[0].strip() for line in sample_data[1:]]
                logger.info(f"disturb rate: {rate_id}")
                logger.info(f"disturb_num: {sample_data[0]}")
                n = 1
                epsilon = 1000
                for count in range(400):  # while epsilon > e_threshold:
                    selected = random.choices(strings, k=1)[0]
                    # disturb_prompt = defence_gramfomer(selected)
                    disturb_prompt = defence_spellchecker(selected)
                    # disturb_prompt = defence_autocorrect(selected)
                    if disturb_prompt == ori_prompt:
                        # The defence fully repaired the perturbation, so the
                        # sample cannot be an adversarial example.
                        Non_AE += 1
                        logger.info(f"dis_prompt: {selected}")
                        logger.info(f"ori_prompt: {ori_prompt}")
                        logger.info("same")
                    else:
                        logger.info("not same")
                        logger.info(f"selected: {selected}")
                        logger.info(f"revised: {disturb_prompt}")
                        L_distance.append(Levenshtein.distance(ori_prompt, disturb_prompt))
                        # whether_robust, dis_loss, stop_early = cal_loss(ori_loss, disturb_prompt, ori_prompt)
                        whether_robust = cal_loss(ori_loss, disturb_prompt, ori_prompt)
                        Non_AE += 1 if whether_robust else 0

                        # efficient_n += 1 if stop_early else 0
                        # AdvSt2i.append(sum(dis_loss) / len(dis_loss))

                    robust_left, robust_right, epsilon = calculate_R(Non_AE, n)
                    robust_re.append((robust_left, robust_right))
                    epsilon_re.append(epsilon)
                    logger.info(f"stop_early: {efficient_n}")
                    logger.info(f"Non_AE: {Non_AE}")
                    logger.info(f"n: {n}")
                    logger.info(f"robust reach: {robust_left} , {robust_right}")
                    logger.info(f"epsilon reach: {epsilon}")
                    print("*" * 120)
                    logger.info("*" * 120)
                    n += 1
                print("*" * 120)
                logger.info("*" * 120)
                logger.info(f"robust = {robust_re}")
                logger.info(f"epsilon = {epsilon_re}")
                logger.info(f"stop_early = {efficient_n}")
                logger.info(f"E_n = {Non_AE}")
                logger.info(f"n = {n}")
                # AdvSt2i is only filled when the commented dis_loss branch is
                # enabled; guard empty lists to avoid a NaN mean.
                logger.info(f"AdvSt2i = {round(np.mean(AdvSt2i), 2) if AdvSt2i else 'n/a'}")
                logger.info(f"OriSt2i = {round(np.mean(ori_loss), 2)}")
                logger.info(f"Levenshtein = {round(np.mean(L_distance), 2) if L_distance else 'n/a'}")
                logger.info(f"robust = {robust_left} , {robust_right}")
                logger.info(f"epsilon = {epsilon}")

            end_time = time.time()
            elapsed_time = end_time - start_time
            hours, remainder = divmod(elapsed_time, 3600)
            minutes, seconds = divmod(remainder, 60)
            print(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
            logger.info(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
            logger.info("&" * 150)