Add EnumeratedShapes ANE compatibility
Seba committed Dec 30, 2024
1 parent 20e3572 commit 46b73ba
Showing 2 changed files with 134 additions and 34 deletions.
80 changes: 67 additions & 13 deletions convert.py
@@ -1,16 +1,59 @@
import sys
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import coremltools as ct
from model import Model, MaskedLMModel
from model import Model, MaskedLMModel, ModernBertRotaryEmbedding

"""
Convert a ModernBERT HuggingFace model to CoreML.
"""

torch.set_grad_enabled(False)

class ModelWrapper(nn.Module):
def __init__(self, model: Model, max_length=8192):
super().__init__()
self.model: Model = model
self.global_rope = ModernBertRotaryEmbedding(model.config.hidden_size // model.config.num_heads, model.config.global_rope_theta)
self.local_rope = ModernBertRotaryEmbedding(model.config.hidden_size // model.config.num_heads, model.config.local_rope_theta)
self.global_cos, self.global_sin = self.global_rope(torch.zeros((1,)), torch.arange(max_length).unsqueeze(0))
self.local_cos, self.local_sin = self.local_rope(torch.zeros((1,)), torch.arange(max_length).unsqueeze(0))

self.global_cos = self.global_cos.squeeze(0) # .T
self.global_sin = self.global_sin.squeeze(0) # .T
self.local_cos = self.local_cos.squeeze(0) # .T
self.local_sin = self.local_sin.squeeze(0) # .T

def forward(self, input_ids, sequence_length):
mask = torch.full_like(input_ids, 1)
mask_x = torch.cumsum(mask, dim=1) - 1
local_sin = self.local_sin[mask_x].transpose(-1, -2)
local_cos = self.local_cos[mask_x].transpose(-1, -2)
global_sin = self.global_sin[mask_x].transpose(-1, -2)
global_cos = self.global_cos[mask_x].transpose(-1, -2)
mask_x = mask_x[:, None]
distances = torch.abs(mask_x - torch.permute(mask_x, (0, 2, 1)))
distances = distances <= (model.config.local_attention_window_size // 2)
zeros = torch.zeros_like(distances, dtype=torch.float16)
# The mask rows for tokens past sequence_length are all -inf, so softmax
# outputs NaNs for those positions; let's see if that causes issues.
global_attention_mask = torch.where((mask_x < sequence_length), zeros, -torch.inf)
sliding_window_mask = torch.where((mask_x < sequence_length) & distances, zeros, -torch.inf)

# return local_sin

return self.model(
input_ids,
global_attention_mask,
local_sin=local_sin,
local_cos=local_cos,
global_sin=global_sin,
global_cos=global_cos,
sliding_window_mask=global_attention_mask,
)

model_name_or_path = "answerdotai/ModernBERT-base"
max_seq_len = 1024
if len(sys.argv) == 3:
@@ -25,28 +68,36 @@

print(f"Converting {model_name_or_path} to CoreML...")
model = MaskedLMModel.from_pretrained(model_name_or_path).eval()
model.rotate()
# model.layers = model.layers[:4]
# model.rotate()
wmodel = ModelWrapper(model)

input_ids = torch.zeros( (1, max_seq_len), dtype=torch.int)
input_ids[..., :] = 50283 # PAD
seq = torch.tensor([50281,510,5347,273,6181,310,50284,15,50282], dtype=torch.int)
input_ids[..., :seq.shape[-1]] = seq
mask = torch.zeros((1,1,max_seq_len,max_seq_len))
mask[:,:,seq.shape[-1]:,:] = -1e4
mask[:,:,:,seq.shape[-1]:] = -1e4
sequence_length = torch.tensor((10,), dtype=torch.int32)
# mask = torch.zeros((1,1,max_seq_len,max_seq_len))
# mask[:,:,seq.shape[-1]:,:] = -1e4
# mask[:,:,:,seq.shape[-1]:] = -1e4

output_name = "hidden_states" if isinstance(model, MaskedLMModel) else "logits"

input_shape = ct.EnumeratedShapes(shapes=[[1, 256], [1, 512], [1, 1024], [1, 2048]])

mlmodel= ct.convert(
torch.jit.trace(model, (input_ids, mask)),
# torch.jit.trace(model, (input_ids, mask)),
torch.jit.trace(wmodel, (input_ids, sequence_length)),
inputs=[
ct.TensorType(name="input_ids", shape=input_ids.shape, dtype=np.int32),
ct.TensorType(name="mask", shape=mask.shape, dtype=np.float16, default_value=np.zeros_like(mask).astype(np.float16)),
# ct.TensorType(name="input_ids", shape=input_ids.shape, dtype=np.int32),
ct.TensorType(name="input_ids", shape=input_shape, dtype=np.int32),
# ct.TensorType(name="mask", shape=mask.shape, dtype=np.float16, default_value=np.zeros_like(mask).astype(np.float16)),
ct.TensorType(name="sequence_length", shape=sequence_length.shape, dtype=np.int32),
],
outputs=[
ct.TensorType(name=output_name),
],
minimum_deployment_target=ct.target.macOS14,
minimum_deployment_target=ct.target.iOS16,
compute_precision=ct.precision.FLOAT16,
# For initial prediction:
compute_units=ct.ComputeUnit.CPU_AND_NE,
@@ -55,7 +106,8 @@

input_output_descriptions = {
"input_ids": "Indices of input sequence tokens in the vocabulary",
"mask": "Mask for defining which tokens should attend to each other. 0 means attend and large negative number (e.g. -1e4) means do not attend.",
# "mask": "Mask for defining which tokens should attend to each other. 0 means attend and large negative number (e.g. -1e4) means do not attend.",
"sequence_length": "Non padded number of tokens in input_ids",
"hidden_states": "Raw outputs from the model. Typically further processed by a task-specific head.",
"logits": "Un-normalized per-token predictions.",
}
@@ -67,11 +119,13 @@

mlmodel.user_defined_metadata["Source Model"] = model_name_or_path

mlmodel.save(f"{model_name_or_path.replace('/', '-')}-{max_seq_len}.mlpackage")
mlmodel.save(f"{model_name_or_path.replace('/', '-')}-{max_seq_len}-NoRot.mlpackage")

model = MaskedLMModel.from_pretrained(model_name_or_path).eval() # Reload non-rotated model.
coreml_out = torch.from_numpy(mlmodel.predict({"input_ids": input_ids.numpy(), "mask": mask.numpy()})[output_name])
torch_out = model(input_ids, mask)
# coreml_out = torch.from_numpy(mlmodel.predict({"input_ids": input_ids.numpy(), "mask": mask.numpy()})[output_name])
coreml_out = torch.from_numpy(mlmodel.predict({"input_ids": input_ids.numpy(), "sequence_length": sequence_length.numpy()})[output_name])
# torch_out = model(input_ids, mask)
torch_out = wmodel(input_ids, sequence_length)
# Sometimes useful for debugging.
# print("CoreML Top 4\n", coreml_out.topk(4, dim=1))
# print("Torch Top 4", torch_out.topk(4, dim=1))
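
A hypothetical usage sketch of the exported package, showing what ct.EnumeratedShapes buys here: the same model can be fed any of the enumerated lengths (256, 512, 1024, 2048) with padding plus a sequence_length input, instead of a fixed-shape attention mask. The file name and token values mirror the script above; treat them as assumptions.

import numpy as np
import coremltools as ct

mlmodel = ct.models.MLModel("answerdotai-ModernBERT-base-1024-NoRot.mlpackage")

input_ids = np.full((1, 256), 50283, dtype=np.int32)               # pad to an enumerated length
input_ids[0, :9] = [50281, 510, 5347, 273, 6181, 310, 50284, 15, 50282]
sequence_length = np.array([9], dtype=np.int32)                    # non-padding token count

out = mlmodel.predict({"input_ids": input_ids, "sequence_length": sequence_length})
hidden_states = out["hidden_states"]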
88 changes: 67 additions & 21 deletions model.py
@@ -38,7 +38,8 @@ def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
x = self.embeddings(input_ids) # ~570us on CPU for 512 tokens.
if self.config.match_hf and isinstance(self.norm, LayerNorm):
# bc1s LayerNorm introduces a slight numerical accuracy drift.
return F.layer_norm(x, (x.size()[-1],), self.norm.weight, self.norm.bias, self.norm.eps).transpose(-1,-2).unsqueeze(2)
# return F.layer_norm(x, (x.size()[-1],), self.norm.weight, self.norm.bias, self.norm.eps).transpose(-1,-2).unsqueeze(2)
return F.layer_norm(x, (self.config.hidden_size,), self.norm.weight, self.norm.bias, self.norm.eps).transpose(-1,-2).unsqueeze(2)
x = x.transpose(-1,-2).unsqueeze(2) # to bc1s
return self.norm(x)

@@ -58,7 +59,7 @@ def __init__(self, config: Config, layer_index: int):

self.out = nn.Conv2d(config.hidden_size, config.hidden_size, kernel_size=1, bias=False)

def forward(self, x, position_ids, attention_mask, sliding_window_mask=None):
def forward(self, x, position_ids, attention_mask, sliding_window_mask=None, sin=None, cos=None):
"""
x: (bs, hidden_size, 1, seq_length)
position_ids: (bs, seq_length)
@@ -71,14 +72,16 @@ def forward(self, x, position_ids, attention_mask, sliding_window_mask=None):
qkv = F.linear(x.squeeze(2).transpose(-2,-1), self.qkv.weight.squeeze()).transpose(-2,-1).unsqueeze(2)
else:
qkv = self.qkv(x)
return self.split_attention(qkv, self._attention_mask(attention_mask, sliding_window_mask), sin, cos, headdim=self.dim_head)
q,k,v = qkv.chunk(3, dim=1)
q = q.view(b, self.config.num_heads, self.dim_head, s)
k = k.view(b, self.config.num_heads, self.dim_head, s)
v = v.view(b, self.config.num_heads, self.dim_head, s)

# RoPE
cos, sin = self.rotary_emb(x, position_ids=position_ids)
cos, sin = cos.transpose(-1,-2), sin.transpose(-1 ,-2)
if sin is None:
cos, sin = self.rotary_emb(x, position_ids=position_ids)
cos, sin = cos.transpose(-1,-2), sin.transpose(-1 ,-2)
q, k = apply_rotary_pos_emb(q, k, cos, sin)

# Switch between global or local attention as appropriate.
@@ -87,6 +90,33 @@ def forward(self, x, position_ids, attention_mask, sliding_window_mask=None):
attn = self.original_attn(q, k, v, mask, self.config.num_heads, self.dim_head)
return self.out(attn)

def split_attention(self, qkv, attention_mask, sin, cos, headdim):
"""
Attention implementations that use reshape do not appear to be EnumeratedShapes compatible,
so this variant splits the heads into separate tensors with torch.split instead of reshaping.
"""
qkv = qkv.squeeze(2)
all_heads = torch.split(qkv, self.dim_head, 1) # [(bsz, dim_head, length)] * 3 * num_heads
qheads, kheads, vheads = all_heads[:self.config.num_heads], all_heads[self.config.num_heads:-self.config.num_heads], all_heads[-self.config.num_heads:]
attns = []
for i in range(self.config.num_heads):
qh = qheads[i]
kh = kheads[i]
vh = vheads[i]

qh, kh = apply_rotary_pos_emb(qh, kh, cos, sin, headdim=headdim)
scores = qh.transpose(-1, -2) @ kh
scores *= self.dim_head
scores += attention_mask
weights = scores.softmax(-1)
attn = vh @ weights
attns.append(attn) # (1, headdim, sequence_length)

attn = torch.concat(attns, dim=1).unsqueeze(-2)
out = self.out(attn)
return out


@staticmethod
def original_attn(q, k, v, mask, heads, dim_head):
bs = q.size(0)
@@ -143,6 +173,7 @@ def forward(self, x):
class Block(nn.Module):
def __init__(self, config: Config, layer_index: int):
super().__init__()
self.config = config
self.layer_index = layer_index
self.pre_attn_norm = LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) if layer_index != 0 else nn.Identity()
self.attn = Attention(config, layer_index)
@@ -151,8 +182,14 @@ def __init__(self, config: Config, layer_index: int):
# Optional transform for residual connection. Useful if applying QuaRot-style rotations.
self.residual_transform = nn.Identity()

def forward(self, x, position_ids, attention_mask, sliding_window_mask=None):
x = self.residual_transform(x) + self.attn(self.pre_attn_norm(x), position_ids, attention_mask, sliding_window_mask)
def forward(self, x, position_ids, attention_mask, sliding_window_mask=None, global_sin=None, global_cos=None, local_sin=None, local_cos=None):
if self.layer_index % self.config.global_attn_every_n_layers == 0:
sin = global_sin
cos = global_cos
else:
sin = local_sin
cos = local_cos
x = self.residual_transform(x) + self.attn(self.pre_attn_norm(x), position_ids, attention_mask, sliding_window_mask, sin=sin, cos=cos)
return x + self.mlp(self.pre_mlp_norm(x))

class MaskedLMHead(nn.Module):
@@ -178,12 +215,17 @@ def __init__(self, config: Config, head: Optional[nn.Module]=None):
self.head = head(config) if head else nn.Identity()
self.unrotate = nn.Identity()

def forward(self, x, attention_mask):
sliding_window_mask = Attention.sliding_window_mask(self.config, attention_mask) # Do all CPU work first.
def forward(self, x, attention_mask, global_sin=None, global_cos=None, local_sin=None, local_cos=None, sliding_window_mask=None):
if sliding_window_mask is None:
sliding_window_mask = Attention.sliding_window_mask(self.config, attention_mask) # Do all CPU work first.
x = self.embeddings(x)
position_ids = torch.arange(x.shape[-1], device=x.device).unsqueeze(0)
if global_sin is None:
position_ids = torch.arange(x.shape[-1], device=x.device).unsqueeze(0)
else:
position_ids = None
for layer in self.layers:
x = layer(x, position_ids, attention_mask, sliding_window_mask)
x = layer(x, position_ids, attention_mask, sliding_window_mask, global_sin=global_sin, global_cos=global_cos, local_sin=local_sin, local_cos=local_cos)
return x
x = self.ln_f(x)
x = self.unrotate(x)
return self.head(x) # MaskedLM, Classification, etc. or no-op
@@ -269,6 +311,9 @@ def forward(self, inputs):
out = out * self.weight.view(1, self.num_channels, 1, 1)

return out

def stable_forward(self, inputs):
pass

def __repr__(self):
return f'LayerNorm(({self.num_channels},), eps={self.eps}, elementwise_affine={self.elementwise_affine})'
@@ -332,17 +377,18 @@ def forward(self, x, position_ids, seq_len=None):
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

def rotate_half(x):
# Modified for BC1S tensor shape.
x1 = x[:, :, : x.shape[-2] // 2, :] # (B, nh, hs/2, T)
x2 = x[:, :, x.shape[-2] // 2 :, :] # (B, nh, hs/2, T)
return torch.cat((-x2, x1), dim=-2) # (B, nh, hs, T)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
def rotate_half(x, headdim=None):
if headdim is None:
headdim = x.shape[-2]
x1 = x[:, : headdim // 2, :]
x2 = x[:, headdim // 2 :, :]
return torch.cat((-x2, x1), dim=-2)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1, headdim=None):
# cos = cos.unsqueeze(unsqueeze_dim)
# sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q, headdim=headdim) * sin)
k_embed = (k * cos) + (rotate_half(k, headdim=headdim) * sin)
return q_embed, k_embed

# Orthogonal Rotation
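
For reference, a generic sketch (not a drop-in copy of split_attention above) of the reshape-free idea: heads are separated with torch.split along the channel dimension of channels-first tensors, so no reshape or view is needed anywhere. It uses the conventional softmax-attention orientation and 1/sqrt(dim_head) scaling; names and shapes are illustrative.

import torch

def split_head_attention(q, k, v, mask, dim_head):
    # q, k, v: (batch, num_heads * dim_head, seq_len); mask: additive, (batch, seq_len, seq_len).
    outs = []
    for qh, kh, vh in zip(q.split(dim_head, dim=1),
                          k.split(dim_head, dim=1),
                          v.split(dim_head, dim=1)):
        scores = qh.transpose(-1, -2) @ kh * dim_head ** -0.5   # (batch, seq, seq)
        weights = (scores + mask).softmax(dim=-1)               # normalize over keys
        outs.append(vh @ weights.transpose(-1, -2))             # (batch, dim_head, seq)
    return torch.cat(outs, dim=1)                               # (batch, num_heads * dim_head, seq)

q = k = v = torch.randn(1, 12 * 64, 256)
out = split_head_attention(q, k, v, torch.zeros(1, 256, 256), 64)   # (1, 768, 256)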
