From 13b744e806323aa3b5b05edbc34a14874b36961e Mon Sep 17 00:00:00 2001 From: Paul Elliott Date: Thu, 6 Feb 2025 13:06:14 -0500 Subject: [PATCH] feat(models): for similarity, support transformers model --- xaitk_saliency_demo/app/config.py | 10 +++++--- xaitk_saliency_demo/app/ml/models.py | 37 ++++++++++++++++++++++++++-- xaitk_saliency_demo/app/ml/xai.py | 23 ++++++++++++++++- 3 files changed, 64 insertions(+), 6 deletions(-) diff --git a/xaitk_saliency_demo/app/config.py b/xaitk_saliency_demo/app/config.py index 2125fd4..5f0ce5f 100644 --- a/xaitk_saliency_demo/app/config.py +++ b/xaitk_saliency_demo/app/config.py @@ -8,6 +8,10 @@ # Task => model "model_active": "SimilarityResNet50", "model_available": [ + { + "title": "microsoft/resnet-50", + "value": "transformers:similarity:microsoft/resnet-50", + }, {"title": "ResNet-50", "value": "SimilarityResNet50"}, {"title": "AlexNet", "value": "SimilarityAlexNet"}, {"title": "VGG-16", "value": "SimilarityVgg16"}, @@ -39,15 +43,15 @@ {"title": "Sliding Window Stack", "value": "SlidingWindowStack"}, ], # Task => model - "model_active": "transformers:microsoft/resnet-50", + "model_active": "transformers:classification:microsoft/resnet-50", "model_available": [ { "title": "microsoft/resnet-50", - "value": "transformers:microsoft/resnet-50", + "value": "transformers:classification:microsoft/resnet-50", }, { "title": "google/vit-base-patch16-224", - "value": "transformers:google/vit-base-patch16-224", + "value": "transformers:classification:google/vit-base-patch16-224", }, {"title": "ResNet-50", "value": "ClassificationResNet50"}, {"title": "AlexNet", "value": "ClassificationAlexNet"}, diff --git a/xaitk_saliency_demo/app/ml/models.py b/xaitk_saliency_demo/app/ml/models.py index 8a66e67..8a2df29 100644 --- a/xaitk_saliency_demo/app/ml/models.py +++ b/xaitk_saliency_demo/app/ml/models.py @@ -120,10 +120,18 @@ class TransformersModel(AbstractModel): def is_transformers_model(model_name: str) -> bool: return 
model_name.startswith(TransformersModel.TRANSFORMERS_PREFIX) + @staticmethod + def get_task(model_name: str) -> str: + return model_name.split(":")[1] + + @staticmethod + def get_hub_id(model_name: str) -> str: + return model_name.split(":")[-1] + def __init__(self, server, model_name, device=None): if device is None: device = DEVICE - hub_id = model_name[len(self.TRANSFORMERS_PREFIX) :] + hub_id = TransformersModel.get_hub_id(model_name) model = pipeline( model=hub_id, device=device, @@ -252,6 +260,25 @@ def __init__(self, server): ) +class SimilarityTransformers(TransformersModel, SimilarityRun): + def __init__(self, server, model_name): + super().__init__(server, model_name) + self._model.model.config.output_hidden_states = True + + def predict(self, input) -> np.ndarray: + image = Image.fromarray(input) + processed = self._model.image_processor(images=image, return_tensors="pt") + device = next(self._model.model.parameters()).device + processed = { + k: (v.to(device) if hasattr(v, "to") else v) for k, v in processed.items() + } + outputs = self._model.model(**processed) + + feature_descriptor = outputs.hidden_states[-1].mean(dim=1) + + return feature_descriptor[0].cpu().detach().numpy().flatten() + + # ----------------------------------------------------------------------------- # Detection # ----------------------------------------------------------------------------- @@ -349,7 +376,13 @@ def get_model(server, model_name): return MODEL_INSTANCES[server][model_name] if TransformersModel.is_transformers_model(model_name): - model = TransformersClassificationModel(server, model_name) + task = TransformersModel.get_task(model_name) + if task == "classification": + model = TransformersClassificationModel(server, model_name) + elif task == "similarity": + model = SimilarityTransformers(server, model_name) + else: + raise ValueError(f"Unknown transformers task: {task}") else: model = globals()[model_name](server) MODEL_INSTANCES[server][model_name] = model diff --git 
a/xaitk_saliency_demo/app/ml/xai.py b/xaitk_saliency_demo/app/ml/xai.py index 7a5462b..adb1faf 100644 --- a/xaitk_saliency_demo/app/ml/xai.py +++ b/xaitk_saliency_demo/app/ml/xai.py @@ -124,6 +124,22 @@ def get_config(self) -> Dict[str, Any]: return {} +class TransformersDescrModel(ImageDescriptorGenerator): + def __init__(self, model): + self.model = model + + @torch.no_grad() + def generate_arrays_from_images( + self, img_mat_iter: Iterable[np.ndarray] + ) -> Iterable[np.ndarray]: + for img in img_mat_iter: + yield self.model.predict(img).squeeze() + + def get_config(self) -> Dict[str, Any]: + # Required by a parent class. Will not be used in this context. + return {} + + # ----------------------------------------------------------------------------- class Saliency: def __init__(self, model, name, params): @@ -160,7 +176,12 @@ def run(self, input, *_): class SimilaritySaliency(Saliency): def run(self, reference, query): self._saliency.fill = FILL - sal = self._saliency(reference, [query], DescrModel(self._model)) + model = ( + DescrModel(self._model) + if not isinstance(self._model, TransformersModel) + else TransformersDescrModel(self._model) + ) + sal = self._saliency(reference, [query], model) return { "type": "similarity", "saliency": sal,