From e91d9c0f2f9fc62747786e1dd03cf4da2ce069e8 Mon Sep 17 00:00:00 2001 From: Stefan Denner Date: Fri, 12 Apr 2024 15:07:47 +0200 Subject: [PATCH] Relaxed strictness when loading weights The loading of model weights fails due to some transformer version mismatch. Relaxing the strictness when loading the weights fixes this issue. https://github.com/RyanWangZf/MedCLIP/issues/37 --- medclip/modeling_medclip.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/medclip/modeling_medclip.py b/medclip/modeling_medclip.py index 723d820..4268552 100644 --- a/medclip/modeling_medclip.py +++ b/medclip/modeling_medclip.py @@ -145,7 +145,7 @@ def __init__(self, if checkpoint is not None: state_dict = torch.load(os.path.join(checkpoint, constants.WEIGHTS_NAME)) - self.load_state_dict(state_dict) + self.load_state_dict(state_dict, strict=False) print('load model weight from:', checkpoint) def from_pretrained(self, input_dir=None): @@ -182,7 +182,7 @@ def from_pretrained(self, input_dir=None): print('\n Download pretrained model from:', pretrained_url) state_dict = torch.load(os.path.join(input_dir, constants.WEIGHTS_NAME)) - self.load_state_dict(state_dict) + self.load_state_dict(state_dict, strict=False) print('load model weight from:', input_dir) def encode_text(self, input_ids=None, attention_mask=None):