Implementation of Focal Cutout or Focal Masking #458
Comments
Thanks, Sayak! Looks like a great tool. I'd love to be able to reproduce Masked Autoencoders using KerasCV components. Thanks!
My intuition regarding this augmentation was that we would take a few consecutive patches and black out the rest of the image. However, in their official implementation (here) they have simply used inception-style cropping and other augmentations. I'm not sure whether that is what we want. /cc @sayakpaul
Hey folks, I took a look at the paper and the code. These are my thoughts:
I would love to chime in on this augmentation layer if we ever want to build it in keras-cv.
Thanks for your interest and for pointing that out, @ariG23498! @AdityaKane2001 is actually implementing this as a part of GSoC, so I'm not too sure about the scope of taking another contributor on this one.
I understand! Can we have a label for this? @AdityaKane2001, all the very best for your contributions!
@LukeWood let's create a label for GSoC-22-related stuff to better separate them? |
Thanks for the insight! I just have a small question. It looks to me like the patches are masked randomly, not in the contiguous order described in the paper. As you said, the question of RandomResizedCrop also remains, since they have used it with a very low area factor (0.05 to 0.3), which seems quite unusual. Code snippet for reference:

```python
if patch_drop > 0:
    patch_keep = 1. - patch_drop
    # number of patch tokens to keep (excluding the class token at index 0)
    T_H = int(np.floor((x.shape[1] - 1) * patch_keep))
    # random permutation over patch tokens; offset by 1 to keep the class token
    perm = 1 + torch.randperm(x.shape[1] - 1)[:T_H]
    idx = torch.cat([torch.zeros(1, dtype=perm.dtype, device=perm.device), perm])
    x = x[:, idx, :]
```
I guess it's best to clarify this with the authors of the paper. @imisra @MidoAssran could you please help us with the doubts mentioned in #458 (comment) and #458 (comment)? For context, @AdityaKane2001 is trying to implement focal masking (as done in MSN) in KerasCV to allow users to experiment with different masking strategies and study their impact in pre-training schemes. Thanks!
Hi @AdityaKane2001, focal masking is just extracting a small crop from an image-view. An image-view is created via random data augmentations of an image. For efficiency, you can do both simultaneously in the data loader with the RandomResizedCrop function in PyTorch. Key points: On random masking: This code, here, ensures that patch-dropping only happens to the random mask views, and not the focal views (which were already masked in the data loader).
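The point above — that a small crop implicitly masks the rest of the image, so no explicit mask tensor is needed — can be sketched in plain NumPy. This is a minimal illustration, not the MSN or KerasCV implementation; the function name, the square-crop simplification, and the nearest-neighbour resize are my own choices, with crop size and scale range taken from the values discussed in this thread:

```python
import numpy as np

def focal_view(image, out_size=96, scale=(0.05, 0.3), rng=None):
    """Take a small random crop (a 'focal view') of an image.

    Cropping with a low area fraction implicitly masks everything
    outside the crop, which is why the data loader needs no separate
    masking step for focal views.
    """
    rng = rng or np.random.default_rng()
    h, w, _ = image.shape
    area = h * w * rng.uniform(*scale)   # target crop area
    side = int(np.sqrt(area))            # square crop for simplicity
    top = rng.integers(0, h - side + 1)
    left = rng.integers(0, w - side + 1)
    crop = image[top:top + side, left:left + side]
    # Nearest-neighbour resize to the fixed focal-view size.
    ys = np.arange(out_size) * side // out_size
    xs = np.arange(out_size) * side // out_size
    return crop[ys][:, xs]

img = np.zeros((224, 224, 3), dtype=np.float32)  # stand-in for a real image
view = focal_view(img)
print(view.shape)  # (96, 96, 3)
```

In practice (as the comment above says) one would just use torchvision's `RandomResizedCrop` in the data loader rather than hand-rolling the crop.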
Thanks for the clarification! I just have one question. The illustration in the paper suggests that the image is divided into patches, and contiguous patches in a grid are retained while the rest of the image is dropped in the case of focal masking. However, neither the code nor the procedure you mentioned takes this into consideration. Could you please share your thoughts on this?
If you wish to explicitly separate the image-view generation from focal masking for conceptual reasons, you can create the image-views for the focal crops using RandomResizedCrop to a size of 224x224 pixels with a scale range of approximately [0.1, 0.7] (i.e., just multiply the current range by 224/96), and then randomly keep a block of patches (a 6x6 block for the /16 networks), and that should give you the same behaviour. However, one can simply combine those two steps from an implementation perspective to reproduce the same behaviour while improving efficiency.
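The second step of this two-step variant — keeping a contiguous 6x6 block out of the 14x14 patch grid of a 224px /16 network — can be sketched as follows. This is a hedged illustration: the function name and the token layout (a flat `(num_patches, dim)` array with no class token) are my own assumptions, not the MSN code:

```python
import numpy as np

def keep_focal_block(tokens, grid=14, block=6, rng=None):
    """Keep only a randomly placed, contiguous block x block patch region.

    tokens: (grid*grid, dim) array of patch embeddings, row-major over
    the patch grid (class token excluded for simplicity).
    Returns the (block*block, dim) tokens inside the block.
    """
    rng = rng or np.random.default_rng()
    top = rng.integers(0, grid - block + 1)
    left = rng.integers(0, grid - block + 1)
    rows = np.arange(top, top + block)
    cols = np.arange(left, left + block)
    # Convert 2-D block coordinates to flat patch indices.
    idx = (rows[:, None] * grid + cols[None, :]).ravel()
    return tokens[idx]

tokens = np.random.default_rng(0).normal(size=(14 * 14, 768))
kept = keep_focal_block(tokens)
print(kept.shape)  # (36, 768)
```

As the comment notes, fusing this with the crop in the data loader gives the same behaviour without ever materialising the dropped tokens.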
Thanks a lot for the clarification. It is clear now. |
This issue is stale because it has been open for 180 days with no activity. It will be closed if no further activity occurs. Thank you. |
Thanks for reporting the issue! We have consolidated the development of KerasCV into the new KerasHub package, which supports image, text, and multi-modal models. Please read keras-team/keras-hub#1831. KerasHub will support all the core functionality of KerasCV. KerasHub can be installed with !pip install -U keras-hub. Documentation and guides are available at keras.io/keras_hub. With our focus shifted to KerasHub, we are not planning any further development or releases in KerasCV. If you encounter a KerasCV feature that is missing from KerasHub, or would like to propose an addition to the library, please file an issue with KerasHub. |
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you. |
Not sure if any other work has implemented and investigated this approach of Focal Masking before but [1] combines Focal Masking and Random Masking to improve self-supervised pre-training for learning visual features.
The idea of Focal Masking (and comparison to Random Masking) is demonstrated in the figure below:

Source: [1]

Masking strategies have been super useful for NLP (BERT). They have shown good promise for vision too (Masked Autoencoders, BeiT, SimMIM, Data2Vec, etc.). I believe it would be useful for KerasCV users because it would allow them to pre-train models using different masking strategies and investigate their impact.
Random Masking is quite similar to Cutout (TensorFlow Addons module) and that is why I used the term Focal Cutout.
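The "Focal Cutout" framing above can be read as the inverse of Cutout: Cutout erases one region and keeps the rest, while Focal Cutout keeps one contiguous patch-aligned region and erases the rest. A minimal NumPy sketch of that idea (the function name and defaults are illustrative, not a KerasCV or TensorFlow Addons API):

```python
import numpy as np

def focal_cutout(image, patch=16, block=6, rng=None):
    """Zero out everything except one contiguous block of patches.

    The kept region is aligned to the patch grid and placed at random,
    mirroring the focal-masking illustration in the MSN paper.
    """
    rng = rng or np.random.default_rng()
    h, w, _ = image.shape
    grid_h, grid_w = h // patch, w // patch
    top = rng.integers(0, grid_h - block + 1) * patch
    left = rng.integers(0, grid_w - block + 1) * patch
    size = block * patch
    out = np.zeros_like(image)
    out[top:top + size, left:left + size] = image[top:top + size,
                                                  left:left + size]
    return out

img = np.ones((224, 224, 3), dtype=np.float32)
masked = focal_cutout(img)
print(int(masked.sum()))  # 96 * 96 * 3 = 27648 pixels survive
```

Unlike cropping, this keeps the full image resolution, which makes the masking explicit but wastes compute on zeroed pixels; that trade-off is why the MSN authors fold the masking into the crop instead.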
References
[1] Masked Siamese Networks: https://arxiv.org/abs/2204.07141
/cc: @LukeWood