Skip to content

Commit

Permalink
feat(models): for similarity, support transformers model
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Feb 11, 2025
1 parent 6d3d927 commit 13b744e
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 6 deletions.
10 changes: 7 additions & 3 deletions xaitk_saliency_demo/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
# Task => model
"model_active": "SimilarityResNet50",
"model_available": [
{
"title": "microsoft/resnet-50",
"value": "transformers:similarity:microsoft/resnet-50",
},
{"title": "ResNet-50", "value": "SimilarityResNet50"},
{"title": "AlexNet", "value": "SimilarityAlexNet"},
{"title": "VGG-16", "value": "SimilarityVgg16"},
Expand Down Expand Up @@ -39,15 +43,15 @@
{"title": "Sliding Window Stack", "value": "SlidingWindowStack"},
],
# Task => model
"model_active": "transformers:microsoft/resnet-50",
"model_active": "transformers:classification:microsoft/resnet-50",
"model_available": [
{
"title": "microsoft/resnet-50",
"value": "transformers:microsoft/resnet-50",
"value": "transformers:classification:microsoft/resnet-50",
},
{
"title": "google/vit-base-patch16-224",
"value": "transformers:google/vit-base-patch16-224",
"value": "transformers:classification:google/vit-base-patch16-224",
},
{"title": "ResNet-50", "value": "ClassificationResNet50"},
{"title": "AlexNet", "value": "ClassificationAlexNet"},
Expand Down
37 changes: 35 additions & 2 deletions xaitk_saliency_demo/app/ml/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,18 @@ class TransformersModel(AbstractModel):
def is_transformers_model(model_name: str) -> bool:
return model_name.startswith(TransformersModel.TRANSFORMERS_PREFIX)

@staticmethod
def get_task(model_name: str) -> str:
    """Return the task field of a colon-separated transformers model name.

    The name is split on ":" and the second field is returned, e.g.
    "transformers:similarity:microsoft/resnet-50" -> "similarity".

    Raises:
        IndexError: if ``model_name`` contains no ":" separator.
    """
    # Only the second field is needed, so splitting past it is unnecessary.
    fields = model_name.split(":", 2)
    return fields[1]

@staticmethod
def get_hub_id(model_name: str) -> str:
    """Return the text after the last ":" in a transformers model name.

    For "transformers:similarity:microsoft/resnet-50" this is
    "microsoft/resnet-50". A name without any ":" is returned unchanged.
    """
    _head, _sep, hub_id = model_name.rpartition(":")
    return hub_id

def __init__(self, server, model_name, device=None):
if device is None:
device = DEVICE
hub_id = model_name[len(self.TRANSFORMERS_PREFIX) :]
hub_id = TransformersModel.get_hub_id(model_name)
model = pipeline(
model=hub_id,
device=device,
Expand Down Expand Up @@ -252,6 +260,25 @@ def __init__(self, server):
)


class SimilarityTransformers(TransformersModel, SimilarityRun):
    """Similarity model whose descriptors come from a transformers network.

    NOTE(review): ``self._model`` is presumably the pipeline object built by
    ``TransformersModel.__init__`` -- confirm against the parent class.
    """

    def __init__(self, server, model_name):
        super().__init__(server, model_name)
        # predict() reads ``hidden_states`` from the forward output, so the
        # wrapped network must be configured to return them.
        network_config = self._model.model.config
        network_config.output_hidden_states = True

    def predict(self, input) -> np.ndarray:
        """Compute a flat feature descriptor for a single image array.

        The image is preprocessed with the pipeline's image processor, moved
        to the device of the network's parameters, and run through the
        network. The last hidden state is averaged over dimension 1 and the
        first batch entry is returned as a flattened NumPy array.
        """
        pil_image = Image.fromarray(input)
        batch = self._model.image_processor(images=pil_image, return_tensors="pt")

        # Run on whichever device the network's weights currently live on.
        target_device = next(self._model.model.parameters()).device
        on_device = {}
        for key, value in batch.items():
            # Entries without a ``to`` method are passed through untouched.
            on_device[key] = value.to(target_device) if hasattr(value, "to") else value

        forward_result = self._model.model(**on_device)

        last_hidden = forward_result.hidden_states[-1]
        descriptor = last_hidden.mean(dim=1)

        first_item = descriptor[0]
        return first_item.cpu().detach().numpy().flatten()


# -----------------------------------------------------------------------------
# Detection
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -349,7 +376,13 @@ def get_model(server, model_name):
return MODEL_INSTANCES[server][model_name]

if TransformersModel.is_transformers_model(model_name):
model = TransformersClassificationModel(server, model_name)
task = model_name.split(":")[1]
if task == "classification":
model = TransformersClassificationModel(server, model_name)
elif task == "similarity":
model = SimilarityTransformers(server, model_name)
else:
raise ValueError(f"Unknown transformers task: {task}")
else:
model = globals()[model_name](server)
MODEL_INSTANCES[server][model_name] = model
Expand Down
23 changes: 22 additions & 1 deletion xaitk_saliency_demo/app/ml/xai.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ def get_config(self) -> Dict[str, Any]:
return {}


class TransformersDescrModel(ImageDescriptorGenerator):
    """Adapter that exposes a model's ``predict`` as a descriptor generator."""

    def __init__(self, model):
        # Wrapped model; only its ``predict`` method is used here.
        self.model = model

    @torch.no_grad()
    def generate_arrays_from_images(
        self, img_mat_iter: Iterable[np.ndarray]
    ) -> Iterable[np.ndarray]:
        """Lazily yield one squeezed ``predict`` result per input image."""
        for image in img_mat_iter:
            descriptor = self.model.predict(image)
            yield descriptor.squeeze()

    def get_config(self) -> Dict[str, Any]:
        # Required by a parent class. Will not be used in this context.
        empty_config: Dict[str, Any] = {}
        return empty_config


# -----------------------------------------------------------------------------
class Saliency:
def __init__(self, model, name, params):
Expand Down Expand Up @@ -160,7 +176,12 @@ def run(self, input, *_):
class SimilaritySaliency(Saliency):
def run(self, reference, query):
self._saliency.fill = FILL
sal = self._saliency(reference, [query], DescrModel(self._model))
model = (
DescrModel(self._model)
if not isinstance(self._model, TransformersModel)
else TransformersDescrModel(self._model)
)
sal = self._saliency(reference, [query], model)
return {
"type": "similarity",
"saliency": sal,
Expand Down

0 comments on commit 13b744e

Please sign in to comment.