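"""
Adaptive robustness evaluation of a Stable Diffusion pipeline against perturbed prompts.

Summary (inferred from the code in this file, not an official description):
  * CLIPScore between the *original* prompt text and generated images is used as the
    text-to-image alignment measure;
  * for each perturbed prompt, images are generated batch by batch and compared with
    the original-prompt scores via a Mann-Whitney U test, with early stopping at
    interim analyses (see cal_loss);
  * an adaptive confidence radius epsilon (see calculate_R) decides when enough
    perturbed prompts have been evaluated for a given perturbation level.
"""
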
import os
# Select the GPU before any CUDA context is created.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import logging
import time

import torch
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
import Levenshtein
from scipy.stats import mannwhitneyu
from diffusers import StableDiffusionPipeline
from torchmetrics.multimodal import CLIPScore

from config import (
    e_threshold, origin_prompt_path, sigma, data_path,
    num_inference_steps, num_batch, batch_size,
    model_id, ori_prompt, stop_early
)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.empty_cache()

# Load the diffusion pipeline in half precision on the selected device.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)

transform = transforms.Compose([transforms.ToTensor()])
metric = CLIPScore().to(device)
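# Note: recent torchmetrics versions default CLIPScore to an OpenAI CLIP ViT-L/14
# backbone; pass model_name_or_path=... to change it. The metric expects pixel values
# on the 0-255 scale, which is why calculate_text_image_distance rescales the
# [0, 1] ToTensor output by 255.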
def calculate_text_image_distance(text, image):
    # ToTensor() gives a [C, H, W] float tensor in [0, 1]; CLIPScore expects pixel
    # values in the [0, 255] range, hence the rescaling.
    img = transform(image) * 255
    score = metric(img.to(device), text)
    return score.detach().cpu().numpy().item()

def setup_logger(file_name):
    logger = logging.getLogger(file_name)
    logger.setLevel(logging.INFO)

    handler = logging.FileHandler(file_name)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S')
    handler.setFormatter(formatter)

    logger.addHandler(handler)
    return logger

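# cal_loss() runs a group-sequential comparison: after each batch of images generated
# from the perturbed prompt, a Mann-Whitney U test compares the accumulated CLIPScores
# against the same number of original-prompt scores. The prompt is rejected (return 0)
# as soon as p <= 0.0158, the interim significance level used here; a rejection at an
# early interim (i <= 3) also sets the global stop_early flag so the caller can count
# early terminations. If all num_batch interims pass, the prompt is accepted (return 1).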
def cal_loss(ori_loss, disturb_prompt, ori_prompt):
    global stop_early
    stop_early = 0
    print("dis_prompt", disturb_prompt)
    print("ori_prompt", ori_prompt)
    logger.info(f"dis_prompt: {disturb_prompt}")
    logger.info(f"ori_prompt: {ori_prompt}")
    logger.info("@" * 20)
    dis_interim_loss = []
    for i in range(num_batch):
        # Compare against the first (i + 1) batches of original-prompt scores.
        ori_interim_loss = ori_loss[0: batch_size * (i + 1)]
        generator = torch.Generator(device).manual_seed(1023 + i)
        images = pipe([disturb_prompt] * batch_size, num_inference_steps=num_inference_steps, generator=generator)
        for j in range(batch_size):
            dis_interim_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))

        _, p_val = mannwhitneyu(dis_interim_loss, ori_interim_loss)
        logger.info(f"p_val: {p_val}")
        if p_val <= 0.0158:
            if i <= 3:
                stop_early = 1
            logger.info(f"Reject interim num : {i + 1}")
            return 0, dis_interim_loss
        logger.info("@" * 20)
    logger.info(f"Accept interim num : {num_batch}")
    return 1, dis_interim_loss


def get_AE(sample_data):
    import random
    random.seed(42)
    # Drop the first element and keep the text before the first ':' on each remaining line.
    strings = [line.split(':')[0].strip() for line in sample_data[1:]]
    # Sampling k == len(strings) items without replacement is just a seeded shuffle.
    sampled_strings = random.sample(strings, len(strings))
    return sampled_strings

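# calculate_R() converts the running count E_n of "robust" verdicts over n evaluated
# perturbations into a confidence radius epsilon and a conservative robustness estimate.
# epsilon shrinks roughly like sqrt(log(log n) / n) and grows as the confidence
# parameter sigma (from config) gets smaller; robust subtracts epsilon from E_n before
# normalising by n. The constants (0.6, 1.1, 1.8, 24) are kept verbatim from the
# original code.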
def calculate_R(E_n, n):
    import math
    # Confidence radius of the running estimate.
    epsilon = math.sqrt((0.6 * math.log(math.log(n, 1.1) + 1, 10) + (1.8 ** -1) * math.log(24 / sigma, 10)) / n)
    # Conservative robustness estimate: successes minus the radius, normalised by n.
    robust = (E_n - epsilon) / n
    return robust, epsilon

def get_origin_prompt(origin_prompt_path):
    # Map 1-based line numbers to prompts.
    origin_prompt = {}
    with open(origin_prompt_path, 'r') as file:
        for i, line in enumerate(file, start=1):
            origin_prompt[i] = line.strip()
    return origin_prompt

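# Driver loop. For every original prompt (only index > 3 is processed in this run),
# a reference set of images is generated, then perturbation levels 1-4 are swept and
# perturbed prompts are evaluated until epsilon drops below e_threshold. Note that
# prompt_2 below is currently a hard-coded two-prompt placeholder; the commented-out
# get_AE(sample_data) call suggests the perturbed prompts are normally drawn from the
# per-prompt CSV loaded into df.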
if __name__ == "__main__":
    start_time = time.time()

    origin_prompts = get_origin_prompt(origin_prompt_path)

    for index, ori_prompt in origin_prompts.items():
        logger = setup_logger(f"adaptive_log/log_{index}.log")
        logger.info(f"num_inference_steps: {num_inference_steps}")
        logger.info(f"num_batch: {num_batch}")
        logger.info(f"batch_size: {batch_size}")
        df = pd.read_csv(f"./generate_AE/char_AE/result_{index}.csv")
        logger.info(f"./generate_AE/char_AE/result_{index}.csv")
        efficient_n = 0
        if index <= 3:
            # Prompts 1-3 are skipped in this run.
            pass
        else:
            logger.info(f"ori_prompt: {ori_prompt}")
            # Reference scores: CLIPScore of the original prompt against its own images.
            ori_loss = []
            for i in range(num_batch):
                generator = torch.Generator(device).manual_seed(1023 + i)
                images = pipe([ori_prompt] * batch_size, num_inference_steps=num_inference_steps, generator=generator)
                for j in range(batch_size):
                    ori_loss.append(calculate_text_image_distance(ori_prompt, images.images[j]))
            logger.info(f"ori_loss: {len(ori_loss)} {ori_loss}")
            logger.info("*" * 120)
            for disturb_rate in range(1, 5):
                efficient_n = 0
                E_n, n = 0, 0
                L_distance, AdvSt2i = [], []
                robust_re, epsilon_re = [], []
                prompt_2 = [f"{index}A mysterious and magical forest with glowing mushrooms and hidden creatures.", "Serene meadow with wildflowers swaying in the gentle breeze under a clear blue sky."]
                logger.info(f"disturb rate: {disturb_rate}")
                # logger.info(f"disturb_num: {sample_data[0]}")
                # prompt_2 = get_AE(sample_data)
                for i, disturb_prompt in enumerate(prompt_2):
                    n = i + 1
                    L_distance.append(Levenshtein.distance(ori_prompt, disturb_prompt))
                    whether_robust, dis_loss = cal_loss(ori_loss, disturb_prompt, ori_prompt)
                    if whether_robust:
                        E_n += 1
                    if stop_early:
                        efficient_n += 1
                    AdvSt2i.append(sum(dis_loss) / len(dis_loss))
                    robust, epsilon = calculate_R(E_n, n)
                    robust_re.append(robust)
                    epsilon_re.append(epsilon)
                    logger.info(f"stop_early: {efficient_n}")
                    logger.info(f"E_n: {E_n}")
                    logger.info(f"n: {n}")
                    logger.info(f"robust reach: {robust}")
                    logger.info(f"epsilon reach: {epsilon}")
                    if epsilon <= e_threshold:  # stop condition
                        break
                    print("*" * 120)
                    logger.info("*" * 120)
                print("*" * 120)
                logger.info("*" * 120)
                logger.info(f"robust = {robust_re}")
                logger.info(f"epsilon = {epsilon_re}")
                logger.info(f"stop_early: {efficient_n}")
                logger.info(f"E_n = {E_n}")
                logger.info(f"n = {n}")
                logger.info(f"AdvSt2i = {round(np.mean(AdvSt2i), 2)}")
                logger.info(f"OriSt2i = {round(np.mean(ori_loss), 2)}")
                logger.info(f"Levenshtein = {round(np.mean(L_distance), 2)}")
                logger.info(f"robust = {robust}")
                logger.info(f"epsilon = {epsilon}")

    end_time = time.time()
    elapsed_time = end_time - start_time
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
    logger.info(f"time cost: {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds")
    logger.info("&" * 150)