Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

integrate SAM (segment anything) encoder with Unet #757

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor sam decoder init
  • Loading branch information
Rustem Galiullin committed May 3, 2023
commit 85565ce3d9f612303403536a531ccb85c699eb13
39 changes: 27 additions & 12 deletions segmentation_models_pytorch/decoders/sam/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional, Union, List, Tuple

import torch
from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it a pip package? probably need to add to reqs

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just added it to reqs, or should we make it optional?

from torch.nn import functional as F

from segmentation_models_pytorch.base import (
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(
encoder_depth: int = None,
encoder_weights: Optional[str] = "sam-vit_h",
decoder_use_batchnorm: bool = True,
decoder_channels: List[int] = (256, 128, 64, 32, 16),
decoder_channels: List[int] = 256,
decoder_attention_type: Optional[str] = None,
in_channels: int = 3,
image_size: int = 1024,
Expand All @@ -71,14 +72,9 @@ def __init__(
aux_params: Optional[dict] = None,
):
super().__init__()
from segment_anything import sam_model_registry

sam = sam_model_registry[encoder_name[4:]](
checkpoint=encoder_weights, image_size=image_size, vit_patch_size=vit_patch_size
)

self.pixel_mean = sam.pixel_mean
self.pixel_std = sam.pixel_std
self.pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
self.pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

self.encoder = get_encoder(
encoder_name,
Expand All @@ -87,13 +83,32 @@ def __init__(
weights=encoder_weights,
img_size=image_size,
patch_size=vit_patch_size,
out_chans=decoder_channels,
)
self.prompt_encoder = sam.prompt_encoder

self.decoder = sam.mask_decoder
image_embedding_size = image_size // vit_patch_size
self.prompt_encoder = PromptEncoder(
embed_dim=decoder_channels,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
)

self.decoder = MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=decoder_channels,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=decoder_channels,
iou_head_depth=3,
iou_head_hidden_dim=256,
)

self.segmentation_head = SegmentationHead(
in_channels=decoder_channels[-1],
in_channels=decoder_channels,
out_channels=classes,
activation=activation,
kernel_size=3,
Expand Down Expand Up @@ -155,7 +170,7 @@ def forward(self, x):
x = torch.stack([self.preprocess(img) for img in x])
features = self.encoder(x)
sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None)
low_res_masks, iou_preidctions = self.decoder(
low_res_masks, iou_predictions = self.decoder(
image_embeddings=features,
image_pe=self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def test_sam_encoder(encoder_name, img_size, patch_size):
assert out.size() == torch.Size([1, 256, expected_patches, expected_patches])


@pytest.mark.parametrize("encoder_name", ["sam-vit_b"])
@pytest.mark.parametrize("image_size", [64])
@pytest.mark.parametrize("encoder_name", ["sam-vit_b", "sam-vit_l"])
@pytest.mark.parametrize("image_size", [64, 128])
def test_sam(encoder_name, image_size):
model_class = smp.SAM
model = model_class(encoder_name, encoder_weights=None, image_size=image_size)
Expand Down