biomedclip.py
import open_clip
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer as timm_ViT
from .lora import LoRA_ViT_timm


class BiomedCLIPViT_LoRA(nn.Module):
    """BiomedCLIP vision encoder (a timm ViT-B/16 trunk) wrapped with LoRA adapters."""

    MODEL_TAG = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'

    def __init__(self, lora_rank=4):
        super().__init__()
        self.lora_rank = lora_rank
        biomedclip = open_clip.create_model(self.MODEL_TAG)
        # LoRA-tune the vision transformer; open_clip exposes the timm ViT as visual.trunk
        vit = biomedclip.visual.trunk
        assert isinstance(vit, timm_ViT)
        self.lora_vit = LoRA_ViT_timm(vit_model=vit, r=lora_rank)
    def forward(self, image):
        """Get spatial patch features from the LoRA-tuned vision transformer."""
        B = image.shape[0]
        # Drop the [CLS] token, keeping the 196 patch tokens of the 14x14 grid
        feat = self.lora_vit.lora_vit.forward_features(image)[:, 1:]  # [B, 196, 768]
        # Move channels first before reshaping to a spatial map; reshaping
        # [B, 196, 768] directly would interleave tokens and channels
        feat = feat.permute(0, 2, 1).reshape(B, -1, 14, 14)  # [B, 768, 14, 14]
        return feat


def biomedclip_lora(lora_rank):
    return BiomedCLIPViT_LoRA(lora_rank=lora_rank)
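

# Minimal smoke test (an illustrative sketch, not part of the original module).
# It assumes torch is installed and that this file is run as a module with
# `python -m <package>.biomedclip`, since the relative `.lora` import fails
# when the file is executed directly as a script.
if __name__ == '__main__':
    import torch

    model = biomedclip_lora(lora_rank=4)
    images = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images at BiomedCLIP's 224px input size
    with torch.no_grad():
        feat = model(images)
    print(feat.shape)  # expected: torch.Size([2, 768, 14, 14])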