Layer improvement #90

Merged · 26 commits · Jul 26, 2024
Commits (26)
dc4c313
add learnable ple encodings
AnFreTh Jul 15, 2024
f6aa9ac
add rotary embeddings and attention_net utils
AnFreTh Jul 15, 2024
04101f9
add mambatab, rotaryft and rnn configs
AnFreTh Jul 15, 2024
fe40c47
adapt mambatab config to paper hparams
AnFreTh Jul 15, 2024
30d35e4
include RNN and MambaTab
AnFreTh Jul 15, 2024
21b0afb
add mambatab, rnn, rotary and basisexpandFT to init
AnFreTh Jul 15, 2024
8b19f92
include basemodel basisexpansion
AnFreTh Jul 15, 2024
bad3db6
add util embedding layer for decluttering
AnFreTh Jul 15, 2024
a2c3e71
adapted embedding layer for when no cat features
AnFreTh Jul 15, 2024
5f15d5f
adapted mambular, FT and TabTransformer for new embedding layer class
AnFreTh Jul 15, 2024
fa8cc00
included possible embeddings into MLP and ResNet
AnFreTh Jul 15, 2024
6efc326
adapted configs of models
AnFreTh Jul 15, 2024
ae0bf97
adapt documentation to new configs
AnFreTh Jul 15, 2024
6be79d1
add missing line in MLP
AnFreTh Jul 15, 2024
3d840b9
added seq len to embedding layer
AnFreTh Jul 16, 2024
ae84257
added shuffling embeddings before being passed to mamba blocks
AnFreTh Jul 16, 2024
e1f5c40
shuffle embeddings option in config
AnFreTh Jul 16, 2024
8af86a2
shuffling along second axis
AnFreTh Jul 16, 2024
2029700
add .built attr to model classes
AnFreTh Jul 16, 2024
c7459c4
delete expansion
AnFreTh Jul 17, 2024
96e9449
delete basis expansion
AnFreTh Jul 17, 2024
864d4f0
delete self.paper attribute from mambatab
AnFreTh Jul 17, 2024
d636b28
adapt build model func
AnFreTh Jul 17, 2024
5a5a7cd
adapt build_method
AnFreTh Jul 17, 2024
283a10b
add embedding layer and delete unused models
AnFreTh Jul 26, 2024
19b760c
Merge branch 'develop' into layer_improvement
AnFreTh Jul 26, 2024
102 changes: 102 additions & 0 deletions mambular/arch_utils/attention_net_arch_utils.py
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn


class Reshape(nn.Module):
def __init__(self, j, dim, method="linear"):
super(Reshape, self).__init__()
self.j = j
self.dim = dim
self.method = method

if self.method == "linear":
# Use nn.Linear approach
self.layer = nn.Linear(dim, j * dim)
elif self.method == "embedding":
# Use nn.Embedding approach
self.layer = nn.Embedding(dim, j * dim)
elif self.method == "conv1d":
# Use nn.Conv1d approach
self.layer = nn.Conv1d(in_channels=dim, out_channels=j * dim, kernel_size=1)
else:
raise ValueError(f"Unsupported method '{method}' for reshaping.")

def forward(self, x):
batch_size = x.shape[0]

if self.method == "linear" or self.method == "embedding":
x_reshaped = self.layer(x) # shape: (batch_size, j * dim)
x_reshaped = x_reshaped.view(
batch_size, self.j, self.dim
) # shape: (batch_size, j, dim)
elif self.method == "conv1d":
# For Conv1d, add dummy dimension and reshape
x = x.unsqueeze(-1) # Add dummy dimension for convolution
x_reshaped = self.layer(x) # shape: (batch_size, j * dim, 1)
x_reshaped = x_reshaped.squeeze(-1) # Remove dummy dimension
x_reshaped = x_reshaped.view(
batch_size, self.j, self.dim
) # shape: (batch_size, j, dim)

return x_reshaped


class AttentionNetBlock(nn.Module):
def __init__(
self,
channels,
in_channels,
d_model,
n_heads,
n_layers,
dim_feedforward,
transformer_activation,
output_dim,
attn_dropout,
layer_norm_eps,
norm_first,
bias,
activation,
embedding_activation,
norm_f,
method,
):
super(AttentionNetBlock, self).__init__()

self.reshape = Reshape(channels, in_channels, method)

encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=n_heads,
batch_first=True,
dim_feedforward=dim_feedforward,
dropout=attn_dropout,
activation=transformer_activation,
layer_norm_eps=layer_norm_eps,
norm_first=norm_first,
bias=bias,
)

self.encoder = nn.TransformerEncoder(
encoder_layer,
num_layers=n_layers,
norm=norm_f,
)

self.linear = nn.Linear(d_model, output_dim)
self.activation = activation
self.embedding_activation = embedding_activation

def forward(self, x):
z = self.reshape(x)
x = self.embedding_activation(z)
x = self.encoder(x)
x = z + x
x = torch.sum(x, dim=1)
x = self.linear(x)
x = self.activation(x)
return x
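
A minimal usage sketch for AttentionNetBlock (not part of the diff; the hyperparameter values are illustrative and the import path is assumed from the file location above). Note that in_channels must equal d_model so the residual z + x in forward lines up, and that passing bias to nn.TransformerEncoderLayer requires a recent PyTorch version.

import torch
import torch.nn as nn
from mambular.arch_utils.attention_net_arch_utils import AttentionNetBlock

d_model = 64
block = AttentionNetBlock(
    channels=8,              # j: number of tokens produced by Reshape
    in_channels=d_model,     # must match d_model for the residual connection
    d_model=d_model,
    n_heads=4,
    n_layers=2,
    dim_feedforward=128,
    transformer_activation="gelu",
    output_dim=1,
    attn_dropout=0.1,
    layer_norm_eps=1e-5,
    norm_first=True,
    bias=True,
    activation=nn.Identity(),
    embedding_activation=nn.ReLU(),
    norm_f=nn.LayerNorm(d_model),
    method="linear",
)
out = block(torch.randn(32, d_model))  # (batch, in_channels) -> (batch, output_dim)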
97 changes: 97 additions & 0 deletions mambular/arch_utils/attention_utils.py
@@ -0,0 +1,97 @@
import torch.nn as nn
import torch
from rotary_embedding_torch import RotaryEmbedding
from einops import rearrange
import torch.nn.functional as F
import numpy as np


class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim=-1)
return x * F.gelu(gates)


def FeedForward(dim, mult=4, dropout=0.0):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim),
)


class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head**-0.5
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
self.dropout = nn.Dropout(dropout)
self.rotary = rotary
dim = np.int64(dim / 2)
self.rotary_embedding = RotaryEmbedding(dim=dim)

def forward(self, x):
h = self.heads
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
if self.rotary:
q = self.rotary_embedding.rotate_queries_or_keys(q)
k = self.rotary_embedding.rotate_queries_or_keys(k)
q = q * self.scale

sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)

attn = sim.softmax(dim=-1)
dropped_attn = self.dropout(attn)

out = torch.einsum("b h i j, b h j d -> b h i d", dropped_attn, v)
out = rearrange(out, "b h n d -> b n (h d)", h=h)
out = self.to_out(out)

return out, attn


class Transformer(nn.Module):
def __init__(
self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False
):
super().__init__()
self.layers = nn.ModuleList([])

for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
Attention(
dim,
heads=heads,
dim_head=dim_head,
dropout=attn_dropout,
rotary=rotary,
),
FeedForward(dim, dropout=ff_dropout),
]
)
)

def forward(self, x, return_attn=False):
post_softmax_attns = []

for attn, ff in self.layers:
attn_out, post_softmax_attn = attn(x)
post_softmax_attns.append(post_softmax_attn)

x = attn_out + x
x = ff(x) + x

if not return_attn:
return x

return x, torch.stack(post_softmax_attns)
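
A minimal usage sketch for the Transformer/Attention stack with rotary embeddings (not part of the diff; values are illustrative and the import path is assumed from the file location above). The rotary frequencies are built from half of the model dimension, so dim is chosen here such that dim / 2 does not exceed dim_head.

import torch
from mambular.arch_utils.attention_utils import Transformer

model = Transformer(
    dim=64,          # token embedding dimension
    depth=2,         # number of attention + feed-forward blocks
    heads=8,
    dim_head=64,
    attn_dropout=0.1,
    ff_dropout=0.1,
    rotary=True,     # rotate queries and keys before the dot product
)
x = torch.randn(32, 10, 64)              # (batch, tokens, dim)
out, attns = model(x, return_attn=True)
# out: (32, 10, 64); attns: (depth, batch, heads, tokens, tokens) = (2, 32, 8, 10, 10)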
145 changes: 145 additions & 0 deletions mambular/arch_utils/embedding_layer.py
@@ -0,0 +1,145 @@
import torch
import torch.nn as nn


class EmbeddingLayer(nn.Module):
def __init__(
self,
num_feature_info,
cat_feature_info,
d_model,
embedding_activation=nn.Identity(),
layer_norm_after_embedding=False,
use_cls=False,
cls_position=0,
):
"""
Embedding layer that handles numerical and categorical embeddings.

Parameters
----------
num_feature_info : dict
Dictionary where keys are numerical feature names and values are their respective input dimensions.
cat_feature_info : dict
Dictionary where keys are categorical feature names and values are the number of categories for each feature.
d_model : int
Dimensionality of the embeddings.
embedding_activation : nn.Module, optional
Activation function to apply after embedding. Default is `nn.Identity()`.
layer_norm_after_embedding : bool, optional
If True, applies layer normalization after embeddings. Default is `False`.
use_cls : bool, optional
If True, includes a class token in the embeddings. Default is `False`.
cls_position : int, optional
Position to place the class token, either at the start (0) or end (1) of the sequence. Default is `0`.

Methods
-------
forward(num_features=None, cat_features=None)
Defines the forward pass of the model.
"""
super(EmbeddingLayer, self).__init__()

self.d_model = d_model
self.embedding_activation = embedding_activation
self.layer_norm_after_embedding = layer_norm_after_embedding
self.use_cls = use_cls
self.cls_position = cls_position

self.num_embeddings = nn.ModuleList(
[
nn.Sequential(
nn.Linear(input_shape, d_model, bias=False),
self.embedding_activation,
)
for feature_name, input_shape in num_feature_info.items()
]
)

self.cat_embeddings = nn.ModuleList(
[
nn.Sequential(
nn.Embedding(num_categories + 1, d_model),
self.embedding_activation,
)
for feature_name, num_categories in cat_feature_info.items()
]
)

if self.use_cls:
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
if layer_norm_after_embedding:
self.embedding_norm = nn.LayerNorm(d_model)

self.seq_len = len(self.num_embeddings) + len(self.cat_embeddings)

def forward(self, num_features=None, cat_features=None):
"""
Defines the forward pass of the model.

Parameters
----------
num_features : Tensor, optional
Tensor containing the numerical features.
cat_features : Tensor, optional
Tensor containing the categorical features.

Returns
-------
Tensor
The output embeddings of the model.

Raises
------
ValueError
If no features are provided to the model.
"""
if self.use_cls:
batch_size = (
cat_features[0].size(0)
if cat_features != []
else num_features[0].size(0)
)
cls_tokens = self.cls_token.expand(batch_size, -1, -1)

if self.cat_embeddings and cat_features is not None:
cat_embeddings = [
emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
]
cat_embeddings = torch.stack(cat_embeddings, dim=1)
cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
if self.layer_norm_after_embedding:
cat_embeddings = self.embedding_norm(cat_embeddings)
else:
cat_embeddings = None

if self.num_embeddings and num_features is not None:
num_embeddings = [
emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
]
num_embeddings = torch.stack(num_embeddings, dim=1)
if self.layer_norm_after_embedding:
num_embeddings = self.embedding_norm(num_embeddings)
else:
num_embeddings = None

if cat_embeddings is not None and num_embeddings is not None:
x = torch.cat([cat_embeddings, num_embeddings], dim=1)
elif cat_embeddings is not None:
x = cat_embeddings
elif num_embeddings is not None:
x = num_embeddings
else:
raise ValueError("No features provided to the model.")

if self.use_cls:
if self.cls_position == 0:
x = torch.cat([cls_tokens, x], dim=1)
elif self.cls_position == 1:
x = torch.cat([x, cls_tokens], dim=1)
else:
raise ValueError(
"Invalid cls_position value. It should be either 0 or 1."
)

return x
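
A minimal usage sketch for EmbeddingLayer (not part of the diff; the feature names, dimensions, and import path are assumptions for illustration). Numerical features are passed as a list of (batch, input_dim) float tensors and categorical features as a list of (batch, 1) integer tensors, matching how the forward pass indexes the two ModuleLists.

import torch
import torch.nn as nn
from mambular.arch_utils.embedding_layer import EmbeddingLayer

layer = EmbeddingLayer(
    num_feature_info={"num_0": 1, "num_1": 1},   # numerical feature name -> input dimension
    cat_feature_info={"cat_0": 10},              # categorical feature name -> number of categories
    d_model=32,
    embedding_activation=nn.ReLU(),
    layer_norm_after_embedding=True,
    use_cls=True,
    cls_position=0,
)
num_features = [torch.randn(8, 1), torch.randn(8, 1)]   # one tensor per numerical feature
cat_features = [torch.randint(0, 10, (8, 1))]           # one integer tensor per categorical feature
x = layer(num_features=num_features, cat_features=cat_features)
# x: (8, 4, 32) -> [CLS] token + 1 categorical token + 2 numerical tokens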
38 changes: 38 additions & 0 deletions mambular/arch_utils/learnable_ple.py
@@ -0,0 +1,38 @@
import torch
import torch.nn as nn


class PeriodicLinearEncodingLayer(nn.Module):
def __init__(self, bins=10, learn_bins=True):
super(PeriodicLinearEncodingLayer, self).__init__()
self.bins = bins
self.learn_bins = learn_bins

if self.learn_bins:
# Learnable bin boundaries
self.bin_boundaries = nn.Parameter(torch.linspace(0, 1, self.bins + 1))
else:
self.bin_boundaries = torch.linspace(-1, 1, self.bins + 1)

def forward(self, x):
if self.learn_bins:
# Ensure bin boundaries are sorted
sorted_bins = torch.sort(self.bin_boundaries)[0]
else:
sorted_bins = self.bin_boundaries

# Initialize z with zeros
z = torch.zeros(x.size(0), self.bins, device=x.device)

for t in range(1, self.bins + 1):
b_t_1 = sorted_bins[t - 1]
b_t = sorted_bins[t]
mask1 = x < b_t_1
mask2 = x >= b_t
mask3 = (x >= b_t_1) & (x < b_t)

z[mask1.squeeze(), t - 1] = 0
z[mask2.squeeze(), t - 1] = 1
z[mask3.squeeze(), t - 1] = (x[mask3] - b_t_1) / (b_t - b_t_1)

return z
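
A minimal usage sketch for PeriodicLinearEncodingLayer (not part of the diff; values are illustrative). With learn_bins=True the boundaries are initialized on [0, 1], so the input is assumed to be scaled to that range; each scalar is mapped to a piecewise-linear encoding of length bins, and the bin boundaries remain trainable.

import torch
from mambular.arch_utils.learnable_ple import PeriodicLinearEncodingLayer

ple = PeriodicLinearEncodingLayer(bins=10, learn_bins=True)
x = torch.rand(32, 1)   # single numerical feature, already scaled to [0, 1]
z = ple(x)              # (32, 10) encoding; gradients flow back to the bin boundaries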