# poison_llava.py
import sys
import argparse
import os
import gc
import json
import shutil
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from torchvision import transforms
import torch.optim as optim
from torchvision.utils import save_image
from PIL import Image
import copy
from torchvision.transforms.functional import InterpolationMode
# diff augmentation
# import kornia
from augmentation_zoo import *
# llava
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
def parse_args():
parser = argparse.ArgumentParser(description="Poisoning")
parser.add_argument("--task_data_pth", default='data/task_data/Biden_base_Trump_target', help='task_data_pth folder contains base_train and target_train folders for constructing poison images')
parser.add_argument("--poison_save_pth", default='data/poisons/llava/Biden_base_Trump_target', help='Output path for saving pure poison images & original captions')
parser.add_argument("--iter_attack", type=int, default=4000)
parser.add_argument("--lr_attack", type=float, default=0.2)
parser.add_argument("--diff_aug_specify", type=str, default=None, help='if None, using the default diff_aug of the VLM')
parser.add_argument("--batch_size", type=int, default=60, help='batch size for running the PGD attack. Modify it according to your GPU memory')
args = parser.parse_args()
if args.diff_aug_specify == "None":
args.diff_aug_specify = None
return args
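
# Illustrative invocation (the paths are just the argparse defaults above;
# adjust them to your own data layout):
#   python poison_llava.py \
#       --task_data_pth data/task_data/Biden_base_Trump_target \
#       --poison_save_pth data/poisons/llava/Biden_base_Trump_target \
#       --iter_attack 4000 --lr_attack 0.2 --batch_size 60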
############ model-specific ############
def get_image_encoder_llava():
    '''
    Return: the image encoder, image processor, and the data augmentation used during training.
    image_processor is only used for the sanity check in test_attack_efficacy().
    diff_aug will be used when crafting adversarial examples.
    '''
model_path = "liuhaotian/llava-v1.5-7b"
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=get_model_name_from_path(model_path)
)
vision_model = copy.deepcopy(model.model.vision_tower); vision_model.eval()
# In llava, the forward function of CLIP is wrapped with torch.no_grad, which we get rid of below
image_encoder_ = vision_model.forward.__wrapped__
image_encoder = lambda x: image_encoder_(vision_model, x)
# delete the model (including LLM) to save memory
del model
gc.collect(); torch.cuda.empty_cache()
diff_aug = None
img_size = 336
return image_encoder, image_processor, diff_aug, img_size
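
# Minimal usage sketch for the pieces returned above (a sketch, not part of
# the pipeline; assumes a CUDA device and the 336px input of llava-v1.5's
# CLIP tower):
#   image_encoder, image_processor, diff_aug, img_size = get_image_encoder_llava()
#   x = torch.rand(1, 3, img_size, img_size).cuda()  # a [0,1] image batch
#   emb = image_encoder(normalize(x))                # differentiable; the no_grad wrapper was removed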
############ model-agnostic ############
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
def load_image(image_path, show_image=False):
img = Image.open(image_path).convert('RGB')
if show_image:
plt.imshow(img)
plt.show()
return img
def load_image_tensors(task_data_pth,img_size):
    '''
    Input:
        task_data_pth must contain two subfolders: base_train and target_train;
        task_data_pth/base_train must contain the cap.json caption file.
        img_size is the image size expected by the VLM; used for resizing.
    Output: image tensors from base_train and target_train
    '''
with open(os.path.join(task_data_pth,'base_train','cap.json')) as file:
base_train_cap = json.load(file)
num_total = len(base_train_cap['annotations'])
images_base = []
images_target = []
resize_fn = transforms.Resize(
(img_size, img_size), interpolation=InterpolationMode.BICUBIC
)
for i in range(num_total):
image_id = base_train_cap['annotations'][i]['image_id']
image_base_pth = os.path.join(task_data_pth, 'base_train', f'{image_id}.png')
image_target_pth = os.path.join(task_data_pth, 'target_train', f'{image_id}.png')
images_base.append(transforms.ToTensor()(resize_fn(load_image(image_base_pth))).unsqueeze(0))
images_target.append(transforms.ToTensor()(resize_fn(load_image(image_target_pth))).unsqueeze(0))
images_base = torch.cat(images_base, axis=0)
images_target = torch.cat(images_target, axis=0)
    print(f'Finished loading {num_total} pairs of base and target images for poisoning, size={images_base.size()}')
return images_base, images_target
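
# Expected on-disk layout consumed by load_image_tensors (inferred from the
# loading code; file names come from the image_id fields in cap.json):
#   task_data_pth/
#       base_train/    cap.json, <image_id>.png, ...
#       target_train/  <image_id>.png, ...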
class PairedImageDataset(torch.utils.data.Dataset):
def __init__(self, images_base, images_target):
        '''
        Both inputs are image tensors of shape (num_examples, 3, h, w).
        This dataset is used to construct a dataloader for batching.
        '''
super().__init__()
assert images_base.shape[0] == images_target.shape[0]
self.images_base = images_base
self.images_target = images_target
def __len__(self):
return self.images_base.shape[0]
def __getitem__(self, index):
return self.images_base[index], self.images_target[index]
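
# Usage sketch: a DataLoader over PairedImageDataset yields index-aligned
# (base, target) batches, e.g.
#   ds = PairedImageDataset(images_base, images_target)
#   for xb, xt in torch.utils.data.DataLoader(ds, batch_size=4):
#       ...  # xb, xt: (4, 3, h, w), matched pair by pair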
def embedding_attack_Linf(image_encoder, image_base, image_victim, emb_dist, \
iters=100, lr=1/255, eps=8/255, diff_aug=None, resume_X_adv=None):
    '''
    Optimize X_adv with PGD to minimize emb_dist(img_embed of X_adv, img_embed of image_victim) under an Linf constraint.
    image_encoder: the image embedding function (e.g. CLIP, EVA)
    image_base, image_victim: images BEFORE normalization, in [0,1]
    emb_dist: the distance metric for vision embeddings (such as L2); takes a batch of bs embedding pairs \
        and returns the pair-wise distance for EACH pair in the batch (size = [bs])
    eps: radius of the Linf constraint
    lr: the step size; the update is grad.sign() * lr
    diff_aug: differentiable augmentation applied before the encoder, e.g. RandomResizedCrop
    resume_X_adv: None, or an initialization to resume X_adv from
    return: X_adv in [0,1]
    '''
assert len(image_base.size()) == len(image_victim.size()) and len(image_base.size()) == 4, 'image size length should be 4'
assert image_base.size(0) == image_victim.size(0), 'image_base and image_victim contain different number of images'
bs = image_base.size(0)
device = image_base.device
with torch.no_grad():
embedding_targets = image_encoder(normalize(image_victim))
X_adv = image_base.clone().detach() + (torch.rand(*image_base.shape)*2*eps-eps).to(device)
if resume_X_adv is not None:
print('Resuming from a given X_adv')
X_adv = resume_X_adv.clone().detach()
X_adv.data = X_adv.data.clamp(0,1)
X_adv.requires_grad_(True)
optimizer = optim.SGD([X_adv], lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(iters*0.5)], gamma=0.5)
loss_best = 1e8 * torch.ones(bs).to(device)
X_adv_best = resume_X_adv.clone().detach() if resume_X_adv is not None else torch.rand(*image_base.shape).to(device)
for i in tqdm(range(iters)):
# for i in range(iters):
if diff_aug is not None:
# NOTE: using differentiable randomresizedcrop here
X_adv_input_to_model = normalize(diff_aug(X_adv))
else:
X_adv_input_to_model = normalize(X_adv)
loss = emb_dist(image_encoder(X_adv_input_to_model), embedding_targets) # length = bs
        if i % max(int(iters/1000), 1) == 0:
if (loss < loss_best).sum()>0:
index = torch.where(loss < loss_best)[0]
loss_best[index] = loss.clone().detach()[index].to(loss_best[index].dtype)
X_adv_best[index] = X_adv.clone().detach()[index]
loss = loss.sum()
optimizer.zero_grad()
loss.backward()
        if i % max(int(iters/20), 1) == 0:
print('Iter :{} loss:{:.4f}, lr * 255:{:.4f}'.format(i,loss.item()/bs, scheduler.get_last_lr()[0]*255))
# Linf sign update
X_adv.grad = torch.sign(X_adv.grad)
optimizer.step()
scheduler.step()
X_adv.data = torch.minimum(torch.maximum(X_adv, image_base - eps), image_base + eps)
X_adv.data = X_adv.data.clamp(0,1)
X_adv.grad = None
if torch.isnan(loss):
print('Encounter nan loss at iteration {}'.format(i))
break
with torch.no_grad():
if diff_aug:
print('Using diff_aug')
X_adv_best_input_to_model = normalize(diff_aug(X_adv_best))
else:
print('Not using diff_aug')
X_adv_best_input_to_model = normalize(X_adv_best)
loss = emb_dist(image_encoder(X_adv_best_input_to_model), embedding_targets)
# print('Best Total loss vector:{}'.format(loss))
print('Best Total loss:{:.4f}'.format(loss.mean().item()))
return X_adv_best, loss.detach()
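
# Recap of what one loop iteration above implements (no extra logic): with
# the gradient replaced by its sign and plain SGD as the stepper, the update is
#   X_adv <- clamp_[0,1]( Proj_{||X - image_base||_inf <= eps}( X_adv - lr * sign(dloss/dX_adv) ) )
# where the projection is the minimum/maximum clamp applied after optimizer.step().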
def L2_norm(a,b):
'''
a,b: batched image/representation tensors
'''
assert a.size(0) == b.size(0), 'two inputs contain different number of examples'
bs = a.size(0)
dist_vec = (a-b).view(bs,-1).norm(p=2, dim=1)
return dist_vec
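
# Quick numeric check:
#   L2_norm(torch.zeros(2, 8), torch.ones(2, 8))  # -> tensor([2.8284, 2.8284]), i.e. sqrt(8) per pair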
def save_poison_data(images_to_save, caption_pth, save_path):
    '''
    Save the pure poison dataset in the same folder format as cc_sbu_align.
    Input:
        images_to_save: a batch of image tensors (the perturbed base_train images)
        caption_pth: path to the json caption file of the unpoisoned images (base_train captions)
        save_path: path for saving the poisoned images and the original captions.
    Images must be written as PNG (lossless), not JPEG, so the perturbation survives on disk.
    '''
assert len(images_to_save.size()) == 4, 'images_to_save should be a batch of image tensors, 4 dimension'
with open(caption_pth) as file:
cap = json.load(file)
num_total = len(cap['annotations'])
assert images_to_save.size(0) == num_total, 'numbers of images and captions are different'
# save image using the original image_id
for i in range(num_total):
image_id = cap['annotations'][i]['image_id']
img_pth = os.path.join(save_path, 'image', '{}.png'.format(image_id))
save_image(images_to_save[i],img_pth)
        # rename to .jpg: only the extension changes, so the file keeps its
        # lossless PNG encoding while matching the .jpg naming convention
        img_pth_jpg = os.path.join(save_path, 'image', '{}.jpg'.format(image_id))
        os.rename(img_pth, img_pth_jpg)
# copy the json file
shutil.copyfile(caption_pth, os.path.join(save_path,'cap.json'))
print('Finished saving the pure poison data to {}'.format(save_path))
def test_attack_efficacy(image_encoder, image_processor, task_data_pth, poison_data_pth, img_size, sample_num=20):
    '''
    Sanity check after crafting the poison images.
    Reload image_base, image_target and image_poison from disk,
    run them through the image processor, and check the relative distances in the image embedding space.
    sample_num: only compute statistics over the first sample_num image triples, then average.
    Output: prints the averaged latent_dist(image_base, image_target) and latent_dist(image_poison, image_target),
        plus the pixel distance between base and poison images.
    NOTE: image_processor includes data augmentation. However, when differentiable JPEG is used while creating
        the poison images, the image_processor will not include the JPEG operation.
    '''
# RGB image
images_base, images_target = [], []
images_poison = []
# load data
with open(os.path.join(poison_data_pth,'cap.json')) as file:
cap = json.load(file)
num_total = len(cap['annotations'])
for i in range(num_total):
image_id = cap['annotations'][i]['image_id']
image_base_pth = os.path.join(task_data_pth, 'base_train', f'{image_id}.png')
image_target_pth = os.path.join(task_data_pth, 'target_train', f'{image_id}.png')
image_poison_pth = os.path.join(poison_data_pth, 'image', f'{image_id}.jpg')
images_base.append((load_image(image_base_pth)))
images_target.append((load_image(image_target_pth)))
images_poison.append((load_image(image_poison_pth)))
        if i + 1 >= sample_num:  # keep exactly sample_num triples
            break
resize_fn = transforms.Resize(
(img_size, img_size), interpolation=InterpolationMode.BICUBIC
)
# compute embedding distance
dist_base_target_list = []
dist_poison_target_list = []
pixel_dist_base_poison = [] # Linf distance in pixel space
for i in range(len(images_base)):
image_base, image_target, image_poison = images_base[i], images_target[i], images_poison[i]
emb_base = image_encoder( torch.from_numpy(image_processor(image_base)['pixel_values'][0]).cuda().unsqueeze(0) )
emb_target = image_encoder( torch.from_numpy(image_processor(image_target)['pixel_values'][0]).cuda().unsqueeze(0) )
emb_poison = image_encoder( torch.from_numpy(image_processor(image_poison)['pixel_values'][0]).cuda().unsqueeze(0) )
dist_base_target_list.append( (emb_base - emb_target).norm().item() )
dist_poison_target_list.append( (emb_poison - emb_target).norm().item() )
pixel_dist_base_poison.append( torch.norm(transforms.ToTensor()(resize_fn(image_base)) - transforms.ToTensor()(image_poison), float('inf')) )
dist_base_target_list = torch.Tensor(dist_base_target_list)
dist_poison_target_list = torch.Tensor(dist_poison_target_list)
pixel_dist_base_poison = torch.Tensor(pixel_dist_base_poison)
print('\n Sanity check of the optimization, considering image loading and image processor')
    print(f'>>> ratio between dist_base_target and dist_poison_target:\n{dist_base_target_list/dist_poison_target_list}')
print(f'ratio mean: {(dist_base_target_list/dist_poison_target_list).mean()}')
print(f'>>> Max Linf pixel distance * 255 between base and poison: {(pixel_dist_base_poison*255).max()}')
return
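
# Reading the output: dist_base_target / dist_poison_target noticeably above 1
# indicates the poison embeddings moved toward the targets, and the max Linf
# pixel distance should stay near the eps budget (8/255 * 255 = 8) up to
# resizing and 8-bit rounding effects.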
if __name__ == "__main__":
args = parse_args()
if os.path.exists(args.poison_save_pth):
raise ValueError('{} already exists for saving pure poisoned data. Delete it or choose another path!'.format(args.poison_save_pth))
else:
os.makedirs(os.path.join(args.poison_save_pth,'image'))
        print(f'Poison images will be saved to {args.poison_save_pth}')
print(f'iter_attack {args.iter_attack}, lr_attack {args.lr_attack}')
###### model preparation ######
image_encoder, image_processor, diff_aug, img_size = get_image_encoder_llava()
if args.diff_aug_specify is not None:
diff_aug = get_image_augmentation(augmentation_name=args.diff_aug_specify, image_size=img_size)
else:
print('Using default diff_aug')
###### data preparation ######
images_base, images_target = load_image_tensors(args.task_data_pth,img_size)
dataset_pair = PairedImageDataset(images_base=images_base, images_target=images_target)
dataloader_pair = torch.utils.data.DataLoader(dataset_pair, batch_size=args.batch_size, shuffle=False)
###### Running attack optimization ######
X_adv_list = []
loss_attack_list = []
for i, (image_base, image_victim) in enumerate(dataloader_pair):
# if i == 1:
# break
print('batch_id = ',i)
image_base, image_victim = image_base.cuda(), image_victim.cuda()
X_adv, loss_attack = embedding_attack_Linf(image_encoder=image_encoder, image_base=image_base, image_victim=image_victim, emb_dist=L2_norm, \
iters=args.iter_attack, lr=args.lr_attack/255, eps=8/255, diff_aug=diff_aug, resume_X_adv=None)
X_adv_list.append(X_adv)
loss_attack_list.append(loss_attack)
X_adv = torch.cat(X_adv_list,axis=0)
loss_attack = torch.cat(loss_attack_list,dim=0)
###### Saving poison data ######
save_poison_data(images_to_save=X_adv.cpu(), caption_pth=os.path.join(args.task_data_pth,'base_train','cap.json'), \
save_path=args.poison_save_pth)
# sanity check (taking into consideration of loading images and image processor)
test_attack_efficacy(image_encoder=image_encoder, image_processor=image_processor, \
task_data_pth=args.task_data_pth, poison_data_pth=args.poison_save_pth, img_size=img_size, sample_num=50)
    print(f'Poison images are saved to {args.poison_save_pth}')