diff --git a/ElasticViTMAE.py b/ElasticViTMAE.py new file mode 100644 index 0000000..5622068 --- /dev/null +++ b/ElasticViTMAE.py @@ -0,0 +1,441 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +import numpy as np +import pandas as pd +import numpy as np +from functools import partial + +import torch +import torch.nn as nn +from timm.models.vision_transformer import Block +import torch.nn.functional as F +from huggingface_hub import PyTorchModelHubMixin + + +class PatchEmbed(nn.Module, PyTorchModelHubMixin): + """ 2D Image to Patch Embedding + + flux capactitor - allows us to trace mask + """ + + def __init__( + self, + img_size=[224, 224], + patch_size=[224,1], + in_chans=1, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + ): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]})." + assert W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]})." + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class CustomHead(nn.Module, PyTorchModelHubMixin): + + def __init__(self, head_embed_dim, img_channels, patch_shape, patchify_fn, unpatchify_fn, num_classes): + super().__init__() + + self.head_embed_dim = head_embed_dim + self.img_channels = img_channels + self.patch_shape = patch_shape + self.patchify_fn = patchify_fn + self.unpatchify_fn = unpatchify_fn + self.num_classes = num_classes + + self.conv2d_embed = nn.Linear(self.head_embed_dim, self.patch_shape[0]*self.patch_shape[1], bias=True) + + self.conv2d_1 = torch.nn.Conv2d(self.img_channels, 256, kernel_size=7, padding=3, bias=False) + self.conv2d_1_rl = torch.nn.LeakyReLU() + self.conv2d_1_bn = torch.nn.BatchNorm2d(256) + + self.conv2d_2 = torch.nn.Conv2d(256, 128, kernel_size=5, padding=2, bias=False) + self.conv2d_2_rl = torch.nn.LeakyReLU() + self.conv2d_2_bn = torch.nn.BatchNorm2d(128) + + self.conv2d_3 = torch.nn.Conv2d(128, 64, kernel_size=3, padding=1, bias=False) + self.conv2d_3_rl = torch.nn.LeakyReLU() + self.conv2d_3_bn = torch.nn.BatchNorm2d(64) + + #for semantic segmentation use self.num_classes, otherwise use img_channels + self.conv2d_output = torch.nn.Conv2d(64, self.num_classes, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + + x = self.conv2d_embed(x) + x = x[:, 1:, :] # remove cls token + + x = self.unpatchify_fn(x) + + x = self.conv2d_1(x) + x = self.conv2d_1_rl(x) + x = self.conv2d_1_bn(x) + + x = self.conv2d_2(x) + x = self.conv2d_2_rl(x) + x = self.conv2d_2_bn(x) + + x = self.conv2d_3(x) + x = self.conv2d_3_rl(x) + x = self.conv2d_3_bn(x) + + x = self.conv2d_output(x) + + #apply softmax to get probabilities of classes + #x = F.softmax(x, dim=1) + + # (BCWH -> BWHC) move channels to last + + #apply torch.argmax() to channels dim + #x = torch.argmax(x, dim=-1) + + return x + + +class ElasticViTMAE(nn.Module, PyTorchModelHubMixin): + """ Masked Autoencoder with VisionTransformer backbone + """ + def __init__(self, img_size=[224, 224], patch_size=[16, 16], in_chans=3, + embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, + custom_head=False, full_image_loss=True, classes=10): + super().__init__() + + self.in_chans = in_chans + self.custom_head = custom_head + self.full_image_loss = full_image_loss + self.classes = classes + # -------------------------------------------------------------------------- + # MAE encoder specifics + self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) + self.num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding + + self.blocks = nn.ModuleList([ + Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)# qk_scale=False, + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + # -------------------------------------------------------------------------- + + # -------------------------------------------------------------------------- + # MAE decoder specifics + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding + + self.decoder_blocks = nn.ModuleList([ + Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) # qk_scale=False, + for i in range(decoder_depth)]) + + self.decoder_norm = norm_layer(decoder_embed_dim) + + if self.custom_head: + self.decoder_pred = CustomHead(head_embed_dim=decoder_embed_dim, + img_channels=self.in_chans, + patch_shape=self.patch_embed.patch_size, + patchify_fn=self.patchify, + unpatchify_fn=self.unpatchify, + num_classes = self.classes) # custom decoder to patch + else: + self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size[0] * patch_size[1] * in_chans, bias=True) # decoder to patch + # -------------------------------------------------------------------------- + + self.norm_pix_loss = norm_pix_loss + + self.initialize_weights() + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + # pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) + decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True) + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.cls_token, std=.02) + torch.nn.init.normal_(self.mask_token, std=.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + + p_h = self.patch_embed.patch_size[0] + p_w = self.patch_embed.patch_size[1] + + h = imgs.shape[2] // p_h + w = imgs.shape[3] // p_w + + # x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = imgs.reshape(shape=(imgs.shape[0], self.in_chans, h, p_h, w, p_w)) # one channel + x = torch.einsum('nchpwq->nhwpqc', x) + # x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) # one channel + x = x.reshape(shape=(imgs.shape[0], h * w, p_h * p_w * self.in_chans)) + + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + + p_h = self.patch_embed.patch_size[0] + p_w = self.patch_embed.patch_size[1] + + h = self.patch_embed.grid_size[0] + w = self.patch_embed.grid_size[1] + + # x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = x.reshape(shape=(x.shape[0], h, w, p_h, p_w, self.in_chans)) # one channel + x = torch.einsum('nhwpqc->nchpwq', x) + # imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + imgs = x.reshape(shape=(x.shape[0], self.in_chans, h * p_h, w * p_w)) # one channel + + return imgs + + def random_masking(self, x, patch_idx, len_keep): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + + N, L, D = x.shape # batch, length, dim + #len_keep = len_keep.item() + + # sort noise for each sample + ids_shuffle = torch.argsort(patch_idx, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def forward_encoder(self, x, idx_shuffle, len_keep): + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, idx_shuffle, len_keep) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x, mask, ids_restore + + def forward_decoder(self, x, ids_restore): + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token + + # add pos embed + x = x + self.decoder_pos_embed + + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + if not self.custom_head: + # remove cls token + x = x[:, 1:, :] + + return x + + def forward_loss(self, imgs, pred, mask): + """ + imgs: [N, 3, H, W] + pred: [N, L, p*p*3] + mask: [N, L], 0 is keep, 1 is remove, + """ + target = self.patchify(imgs) + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + if self.full_image_loss: + loss = (pred - target) ** 2 + + batch_size = mask.shape[0] + num_patches = mask.shape[1] + loss = loss.sum() / batch_size / num_patches # loss on all patches + else: + m = mask.unsqueeze(-1) + p = pred * m + t = target * m + loss = (p - t) ** 2 + + batch_size = mask.shape[0] + num_masked = mask[0].sum().item() + + loss = loss.sum() / batch_size / num_masked # loss on mask + + return loss + + def forward_loss_pixel_accuracy(self, pred, label): + out = torch.where(label==pred, 1, 0) + numerator = sum(out.flatten()) + denominator = len(out.flatten()) + accuracy = round((numerator/denominator),5) + loss = 1 - accuracy + return loss + + def forward(self, imgs, labels, idx_shuffle, len_keep): + latent, mask, ids_restore = self.forward_encoder(imgs, idx_shuffle, len_keep) + pred = self.forward_decoder(latent, ids_restore) + loss = self.forward_loss(imgs, pred, mask) + return loss, pred, mask + + +def build_2d_sincos_position_embedding(self, temperature=10000.): + h, w = self.patch_embed.grid_size + grid_w = torch.arange(w, dtype=torch.float32) + grid_h = torch.arange(h, dtype=torch.float32) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h) + assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' + pos_dim = self.embed_dim // 4 + omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim + omega = 1. / (temperature**omega) + out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) + out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) + pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] + + assert self.num_tokens == 1, 'Assuming one and only one token, [cls]' + pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32) + self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1)) + self.pos_embed.requires_grad = False + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size[0], dtype=np.float32) + grid_w = np.arange(grid_size[1], dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) + + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + + return emb diff --git a/README.md b/README.md new file mode 100644 index 0000000..a0650b8 --- /dev/null +++ b/README.md @@ -0,0 +1,138 @@ +--- +license: apache-2.0 +tags: +- vision +- MAE +- model_hub_mixin +- pytorch_model_hub_mixin +datasets: +- patch-the-planet +--- + +# Model Card for ThinkOnward's Geophysical Foundation Model + +This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration: + +This is a model based on [Meta's ViTMAE model](https://huggingface.co/facebook/vit-mae-base), with some modifications to the masking technique. The Geophyiscal Foundation Model, or GFM for short, uses the ViT architecture with masking on traces in 2D seismic images, rather than patches. + +## Model Details + +### Model Description + +ThinkOnward's Geophysical Foundation Model is a pre-trained a Vision Transformer pre-trained on 450 synthetically generated Synthoseis 3D seismic volumes. We use a new elastic architecture and trace masking process to fine-tune the Geophysical Foundation Model for the downstream task +of seismic interpolation. We use 50 3D seismic volumes from the Patch the Planet Challenge, hosted by ThinkOnward as our benchmark hold-out dataset. **Using a Structural Similarity Index Metric (SSIM) to +compare results we document the Geophysical Foundation Model is 2-3 times better than Shang et al. (2023), and similar to Lasscock et al. (2024).** + +- **Developed by:** Ognjen Tanovic and Mike McIntire of ThinkOnward (Shell Portfolio Company) +- **Model type:** MAE +- **License:** Apache 2.0 +- **Based on:** facebook/vit-mae-base + +### Model Sources + +Link to the model repository listed below. This model was also presented as a poster at the AAPG/SEG IMAGE Conference in Houston, Texas August 26th-29th, 2024. + +- **Repository:** https://github.com/thinkonward +- **Conference Poster Abstract:** https://imageevent.aapg.org/portals/26/abstracts/2024/4092088.pdf + +## Uses + +### Direct Use + +This model is a modified version the [ViT MAE](https://huggingface.co/docs/transformers/en/model_doc/vit_mae) architecture. The model was used to pretrain a backbone using 450 synthetically generated seismic volumes. The goal of this project is to demonstrate that Vision Transformers (ViT) with Masked Autoencoders (MAE) can be used to leverage large amounts of unlabeled seismic data through masking to train an encoder to recognize specfic features in seismic data. The pretrained backbone can then be used with a specific downstream task like interpolation, denoising, and segmentation. + +### Downstream Use + +Downstream tasks include: + + Regression: + - Interpolation of missing sections of seismic images + - Denoising seismic data + - Inversion (planned) + Classification: + - Segmentation of horizons + - Segmentation of faults (in progress) + - Segmentation of geobodies (in progress) + +### Out-of-Scope Use + +The backbone of this model was trained using 3D seismic data from the Patch the Planet Challenge hosted by ThinkOnward. Use of this model on anything outside of seismic data, or similar technologies would be out-of-scope and likely have poor performance. + +## How to Get Started with the Model + +You can load the model using: + +```python +import torch +from huggingface_hub import hf_hub_download + +# For root directory +model_path = hf_hub_download("thinkonward/geophysical-foundation-model", "elasticvitmae.bin") + +ElasticVitMAE = torch.load(model_path) +``` + +Once the mode architecture has been defined, you can use `.from_pretrained()` to extract weights! + +```python +model = ElasticViTMAE.from_pretrained("thinkonward/geophysical-foundation-model") +``` + +## Training Details + +### Training Data + +The data used to train the Geophysical Foundation Model was 450 synthetically generated seismic volumes. The data was generated using the [Synthoseis package](https://github.com/sede-open/synthoseis), which is a synthetic seismic data generator. The data was generated using the default rock properties model in the code repository. The data was genereated for the [Patch the Planet Challenge](https://thinkonward.com/app/c/challenges/patch-the-planet), hosted by ThinkOnward. + +**Training Dataset Card:** [patch-the-planet](https://huggingface.co/datasets/thinkonward/patch-the-planet) + +## Evaluation + +#### Testing Data + +Test data was generated using the same Synthoseis package as the training data. The test data was generated using the same rock properties model as the training data. The test data was generated for the [Patch the Planet Challenge](https://thinkonward.com/app/c/challenges/patch-the-planet), hosted by ThinkOnward. + +**Benchmark Dataset Card:** [patch-the-planet-benchmark](https://huggingface.co/datasets/thinkonward/patch-the-planet-benchmark) + +#### Metrics + +**Structural Similarity Index (SSIM)** - The primary metric for comparison of interpolation results is the `scikit-image` implementation of the [Structural Similarity Index](https://scikit-image.org/docs/stable/auto_examples/transform/plot_ssim.html). The Structural Similarity Index is a metric used to measure the similarity between two images. When the SSI equals 1, the images are identical. When the SSI equals 0, the images are completely dissimilar. Please refer to the `scikit-image` docs for more information about the metric, as well as examples of implementation. Similarity will be calculated for all predictions. The minimum and maximum SSI values will be dropped, and the mean SSI score across all predictions will be the final score. + +**Mean Squared Error (MSE):** - The Mean Squared Error is a metric used as a loss metric for this model to measure the average of the squares of the errors between the true and predicted values. The lower the MSE, the better the model is at predicting the values. MSE is used for regression tasks. + +**Cross Entropy Loss:** - The Cross Entropy Loss is a metric was used as a loss metric for this model to measure the average of the loss function for all predictions. The lower the Cross Entropy Loss, the better the model is at predicting the values. Cross Entropy Loss is used for downstream classification and segmentation tasks. + +### Results + +We use 50 3D seismic volumes from the Patch the Planet Challenge, hosted by ThinkOnward as our benchmark hold-out dataset. Using a Structural Similarity Index Metric (SSIM) to +compare results we document the Geophysical Foundation Model is 2-3 times better than Shang et al. (2023), and similar to Lasscock et al. (2024). + + +### Model Architecture and Objective + +![image](src_imgs/src_imgs_model_architecture.png) + +This model uses a modified version of the ViT MAE architecture. The model uses a masking technique on traces in 2D seismic images, rather than patches + +## Citations + +This model was released in conjunction with the presentation of a poster at the 2024 IMAGE Conference in Houston, Texas (August 26-29th, 2024) + +**APA:** + +McIntire, M., Tanovic, O., Mazura, J., Suurmeyer, N., & Pisel, J. (n.d.). Geophysical Foundation Model: Improving results with trace masking. In https://imageevent.aapg.org/portals/26/abstracts/2024/4092088.pdf. 2024 IMAGE Conference, Houston, United States of America. + +**BibTex:** + +@misc {thinkonward_2024, + author = { {ThinkOnward} }, + title = { geophysical-foundation-model (Revision 2f8d6ce) }, + year = 2024, + url = { https://huggingface.co/thinkonward/geophysical-foundation-model }, + doi = { 10.57967/hf/2908 }, + publisher = { Hugging Face } +} + +## Model Card Contact + +Please contact `challenges@thinkonward.com` for questions, comments, or concerns about this model. diff --git a/src_imgs/src_imgs_model_architecture.png b/src_imgs/src_imgs_model_architecture.png new file mode 100644 index 0000000..6c5b64a Binary files /dev/null and b/src_imgs/src_imgs_model_architecture.png differ