Implementation of Focal Cutout or Focal Masking #458


Open · sayakpaul opened this issue May 29, 2022 · 15 comments

@sayakpaul (Contributor) commented May 29, 2022

I'm not sure if any other work has implemented and investigated Focal Masking before, but [1] combines Focal Masking and Random Masking to improve self-supervised pre-training for learning visual features.

The idea of Focal Masking (and comparison to Random Masking) is demonstrated in the figure below:

[Figure: Focal Masking retains a contiguous block of patches, whereas Random Masking drops a random, non-contiguous subset.]

Source: [1]

Masking strategies have been super useful for NLP (BERT). They have shown good promise for Vision too (Masked Autoencoders, BeiT, SimMIM, Data2Vec, etc.). I believe it'd be useful for KerasCV users because it would allow them to pre-train models using different masking strategies and investigate their impact.

Random Masking is quite similar to Cutout (available in TensorFlow Addons), which is why I used the term Focal Cutout. The sketch below illustrates the difference between the two strategies.
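To make the distinction concrete, here is a small, hypothetical NumPy sketch (my own illustration, not code from [1]) of which patch indices each strategy would keep on a ViT-style patch grid:

    # Hypothetical illustration (not from [1]): which patch indices each
    # masking strategy keeps on a ViT-style patch grid.
    import numpy as np

    def random_mask_keep(num_patches, keep_ratio, rng=np.random):
        # Random Masking: keep a random, non-contiguous subset of patches.
        num_keep = int(num_patches * keep_ratio)
        return rng.permutation(num_patches)[:num_keep]

    def focal_mask_keep(grid_size, block_size, rng=np.random):
        # Focal Masking: keep one contiguous block_size x block_size block
        # of patches and drop the rest of the image.
        top = rng.randint(0, grid_size - block_size + 1)
        left = rng.randint(0, grid_size - block_size + 1)
        rows, cols = np.meshgrid(
            np.arange(top, top + block_size),
            np.arange(left, left + block_size),
            indexing="ij",
        )
        return (rows * grid_size + cols).ravel()  # flattened patch indices

    # e.g. a 14x14 grid (a 224px image with 16px patches)
    print(random_mask_keep(196, keep_ratio=0.3))
    print(focal_mask_keep(14, block_size=6))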

References

[1] Masked Siamese Networks: https://arxiv.org/abs/2204.07141

/cc: @LukeWood

@LukeWood (Contributor)

Thanks, Sayak! Looks like a great tool. I'd love to be able to reproduce Masked Autoencoders using KerasCV components.

@AdityaKane2001 (Contributor)

@LukeWood

My intuition regarding this augmentation was that we would take a few consecutive patches and black out the rest of the image. However, in their official implementation (here) they have simply used inception crop and other augmentations. I'm not sure whether that is what we want.

/cc @sayakpaul

@ariG23498 (Contributor)

Hey folks,

I took a look at the paper and the code. These are my thoughts:

  • @AdityaKane2001 rightly pointed out

    in their official implementation (here) they have simply used inception crop and other augmentations

  • The masking actually happens inside the encoder
  • So we first augment the images and then mask the patches accordingly. But I am still not sure why the authors have used RandomResizedCrop as an augmentation before masking. This part does not really sit well with what is written in the paper and the diagram shown.

I would love to chime in on this augmentation layer if we ever decide to build it in keras-cv.

@sayakpaul (Contributor, Author)

Thanks for your interest and for pointing that out, @ariG23498!

@AdityaKane2001 is actually implementing this as a part of GSoC, so I'm not too sure about the scope of taking another contributor on this one.

@ariG23498 (Contributor)

I understand! Can we have a label for GSoC so that we can point out which issues are already taken?

@AdityaKane2001 all the very best for your contributions!

@sayakpaul (Contributor, Author)

@LukeWood, shall we create a label for GSoC-22-related issues to better separate them?

@AdityaKane2001 (Contributor) commented Jun 15, 2022

@ariG23498

Thanks for the insight! I just have a small question. It looks to me like patches are dropped randomly, and not in the contiguous manner illustrated in the paper. As you said, the question of RandomResizedCrop also remains, since they have used it with a very low area factor (0.05 to 0.3), which seems quite unusual.

Code snippet for reference:

    if patch_drop > 0:
        # Fraction of patch tokens to keep after random masking.
        patch_keep = 1. - patch_drop
        T_H = int(np.floor((x.shape[1] - 1) * patch_keep))
        # Randomly permute the patch-token indices, offset by 1 so that
        # the class token at index 0 is never dropped.
        perm = 1 + torch.randperm(x.shape[1] - 1)[:T_H]  # keep class token
        idx = torch.cat([torch.zeros(1, dtype=perm.dtype, device=perm.device), perm])
        # Gather the class token plus the surviving patch tokens.
        x = x[:, idx, :]

@sayakpaul (Contributor, Author) commented Jun 15, 2022

I guess it's best to clarify this with the authors of the paper.

@imisra @MidoAssran could you please help us with the doubts mentioned in #458 (comment) and #458 (comment)?

For context, @AdityaKane2001 is trying to implement focal masking (as done in MSN) in KerasCV to allow users to experiment with different masking strategies and study their impact in pre-training schemes.

Thanks!

@MidoAssran

Hi @AdityaKane2001, focal masking is just extracting a small crop from an image view. An image view is created via random data augmentations of an image. For efficiency, you can do both simultaneously in the data loader with the RandomResizedCrop function in PyTorch.

Key points:
Notice that the crop scale is very small (0.05, 0.3), meaning we are extracting crops that cover between 5% and 30% of the total image area, and then resizing these to 96x96 pixels for efficient batch processing (so that all the focal crops can be processed in the same forward pass). This is simply equivalent to calling RandomResizedCrop with the aforementioned scale and crop size!
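A minimal sketch of what this looks like in the data loader, assuming torchvision (the function name and scale range come from the description above; pil_image is a placeholder input):

    # Sketch: focal masking fused into the data loader. RandomResizedCrop
    # samples a crop covering 5%-30% of the image area and resizes it to
    # 96x96 for efficient batch processing.
    from torchvision import transforms

    focal_crop = transforms.RandomResizedCrop(96, scale=(0.05, 0.3))
    # focal_view = focal_crop(pil_image)  # one 96x96 focal view per call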

On random masking:
Random masking, on the other hand, cannot be implemented in the same way, since it corresponds to dropping non-contiguous patches. Therefore, after creating an image view, the random masking is executed in the encoder by randomly dropping input patches.

This code, here, ensures that patch dropping only happens to the random mask views, and not the focal views (which were already masked in the data loader).

@AdityaKane2001 (Contributor)

@MidoAssran

Thanks for the clarification! I just have one question. The illustration in the paper suggests that, for focal masking, the image is patchified and a block of contiguous patches in the grid is retained while the rest of the image is dropped. However, neither the code nor the procedure you mentioned takes this into consideration. Could you please share your thoughts on this?

@MidoAssran

@AdityaKane2001

If you wish to explicitly separate the image-view generation from focal masking for conceptual reasons, you can create the image views for the focal crops using RandomResizedCrop to a size of 224x224 pixels with a scale range of approximately [0.1, 0.7] (i.e., just multiply the current range by 224/96), and then randomly keep a block of patches (a 6x6 block for the /16 networks), and that should give you the same behaviour.

However, one can simply combine those two steps from an implementation perspective to reproduce the same behaviour while improving efficiency.
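For anyone who wants this explicit two-step variant in code, here is a hedged sketch with a hypothetical helper name (keep_focal_block), assuming a /16 network with a 14x14 patch-token grid at 224x224 resolution:

    # Hypothetical sketch of the two-step equivalent described above:
    # (1) a larger crop resized to 224x224 with the rescaled scale range,
    # (2) keeping one contiguous 6x6 block of the 14x14 patch-token grid.
    import torch
    from torchvision import transforms

    crop_224 = transforms.RandomResizedCrop(224, scale=(0.1, 0.7))

    def keep_focal_block(x, grid=14, block=6):
        # x: (B, N, D) patch tokens with N == grid * grid (no class token).
        top = torch.randint(0, grid - block + 1, (1,)).item()
        left = torch.randint(0, grid - block + 1, (1,)).item()
        rows = torch.arange(top, top + block)
        cols = torch.arange(left, left + block)
        idx = (rows[:, None] * grid + cols[None, :]).reshape(-1)
        return x[:, idx, :]  # (B, block * block, D)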

@AdityaKane2001 (Contributor)

@MidoAssran

Thanks a lot for the clarification. It is clear now.

github-actions (bot)

This issue is stale because it has been open for 180 days with no activity. It will be closed if no further activity occurs. Thank you.

@sachinprasadhs (Collaborator)

Thanks for reporting the issue! We have consolidated the development of KerasCV into the new KerasHub package, which supports image, text, and multi-modal models. Please read keras-team/keras-hub#1831. KerasHub will support all the core functionality of KerasCV.

KerasHub can be installed with !pip install -U keras-hub. Documentation and guides are available at keras.io/keras_hub.

With our focus shifted to KerasHub, we are not planning any further development or releases in KerasCV. If you encounter a KerasCV feature that is missing from KerasHub, or would like to propose an addition to the library, please file an issue with KerasHub.

github-actions (bot)

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions bot added the stale label on Apr 21, 2025