Update DTensor imports to use public module #152

Status: Closed · wants to merge 2 commits
146 changes: 146 additions & 0 deletions src/olmo_core/nn/moe/mlp.py
@@ -0,0 +1,146 @@
import warnings
from typing import Any, Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Shard, distribute_tensor

from ...distributed.utils import get_local_tensor
from ...exceptions import OLMoConfigurationError

__all__ = ["MoEMLP"]


class _ScaleGradient(torch.autograd.Function):
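    # Passes the input through unchanged in the forward pass and scales the gradient by a
    # constant factor in the backward pass (used below to compensate for expert parallelism).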
@staticmethod
@torch.amp.autocast_mode.custom_fwd(device_type="cuda")
def forward(ctx: Any, x: torch.Tensor, scale: float):
ctx.scale = scale
return x

@staticmethod
@torch.amp.autocast_mode.custom_bwd(device_type="cuda")
    def backward(ctx: Any, grad: torch.Tensor):
        return grad * ctx.scale, None  # type: ignore


_scale_gradient: Callable[[torch.Tensor, float], torch.Tensor] = _ScaleGradient.apply # type: ignore


class MoEMLP(nn.Module):
def __init__(
self,
*,
d_model: int,
hidden_size: int,
num_experts: int,
dtype: torch.dtype = torch.float32,
init_device: str = "cpu",
):
super().__init__()
self.d_model = d_model
self.hidden_size = hidden_size
self.num_experts = num_experts

self.gradient_scale: Optional[float] = None
self.experts_per_rank = num_experts

self.w1 = nn.Parameter(
torch.empty(
num_experts,
hidden_size,
d_model,
device=init_device,
dtype=dtype,
),
)
self.w2 = nn.Parameter(
torch.empty(
num_experts,
hidden_size,
d_model,
device=init_device,
dtype=dtype,
),
)
self.w3 = nn.Parameter(
torch.empty(
num_experts,
hidden_size,
d_model,
device=init_device,
dtype=dtype,
),
)

self._gmm = None

try:
import grouped_gemm # type: ignore

self._gmm = grouped_gemm.ops.gmm
except ImportError:
warnings.warn(
"Grouped GEMM not available, so the MoE will be substantially slower. "
"Please install the 'grouped_gemm' package if possible.\n"
"https://github.com/tgale96/grouped_gemm"
)

def scale_grad(self, w: torch.Tensor) -> torch.Tensor:
if self.gradient_scale is None:
return w
return _scale_gradient(w, self.gradient_scale)

def gmm(
self, x: torch.Tensor, w: torch.Tensor, batch_sizes: torch.Tensor, trans_b: bool = False
) -> torch.Tensor:
if self._gmm is not None:
return self._gmm(x, w, batch_sizes, trans_b=trans_b)
        else:
            # Fall back to a plain PyTorch loop: slice the rows of ``x`` assigned to each
            # expert (according to ``batch_sizes``) and multiply by that expert's weights.
            out = []
            start = 0
            for i, size in enumerate(batch_sizes.cpu().numpy()):
                rhs = w[i, :, :].t() if trans_b else w[i, :, :]
                out.append(x[start : start + size, :] @ rhs)
                start += size
            return torch.cat(out)

def forward(self, x: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor:
"""
Compute the expert outputs.

:param x: The input of shape ``(total_tokens, d_model)``.
:param tokens_per_expert: Specifies how many tokens go to each expert. Should be a
1-D ``LongTensor``.
"""
# Scale gradients and get local tensors (in case of expert parallelism).
# shape (all): (experts_per_rank, hidden_size, d_model)
w1, w2, w3 = (
get_local_tensor(self.scale_grad(self.w1)),
get_local_tensor(self.scale_grad(self.w2)),
get_local_tensor(self.scale_grad(self.w3)),
)

        # Compute the MLP (SwiGLU-style feed-forward).
        # shape: (total_tokens, hidden_size)
        x1 = self.gmm(x, w1, tokens_per_expert, trans_b=True)
        # shape: (total_tokens, hidden_size)
        x2 = self.gmm(x, w3, tokens_per_expert, trans_b=True)
        x1 = F.silu(x1) * x2
        # shape: (total_tokens, d_model)
        return self.gmm(x1, w2, tokens_per_expert)

def apply_ep(self, ep_mesh: DeviceMesh):
"""
Apply expert parallelism.
"""
if self.num_experts % ep_mesh.size() != 0:
raise OLMoConfigurationError(
f"'num_experts' ({self.num_experts}) must be divisible by the expert parallel degree ({ep_mesh.size()})."
)

self.experts_per_rank = self.num_experts // ep_mesh.size()
self.gradient_scale = 1.0 / ep_mesh.size()

self.register_parameter("w1", nn.Parameter(distribute_tensor(self.w1, ep_mesh, [Shard(0)])))
self.register_parameter("w2", nn.Parameter(distribute_tensor(self.w2, ep_mesh, [Shard(0)])))
self.register_parameter("w3", nn.Parameter(distribute_tensor(self.w3, ep_mesh, [Shard(0)])))
212 changes: 212 additions & 0 deletions src/olmo_core/nn/moe/router.py
@@ -0,0 +1,212 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple

import torch
import torch.nn as nn

from ...config import Config, DType, StrEnum
from ...exceptions import OLMoConfigurationError

__all__ = ["MoERouter", "MoELinearRouter", "MoERouterConfig", "MoERouterType"]


# NOTE: To enable end-to-end benchmarking without requiring convergence, we
# support a flag that forces the router to assign tokens uniformly across the
# experts. We do this with a custom autograd operation so that PyTorch still
# executes the full set of router operations.
class _UniformExpertAssignment(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, x: torch.Tensor, num_experts: int):
del ctx
out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
out = torch.remainder(out, num_experts)
return out.view(x.shape)


_uniform_expert_assignment: Callable[
[torch.Tensor, int], torch.Tensor
] = _UniformExpertAssignment.apply # type: ignore
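
# Illustrative example (not from this PR): with ``num_experts=4`` and an input
# holding 6 expert indices, the assignment simply cycles over the experts, i.e.
# torch.remainder(torch.arange(6), 4) -> [0, 1, 2, 3, 0, 1], reshaped to the
# input's shape.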


class MoERouterType(StrEnum):
"""
An enumeration of the different MoE router implementations.
"""

default = "default"
"""
➡️ :class:`MoELinearRouter`
"""


@dataclass
class MoERouterConfig(Config):
"""
A configuration class for easily building any of the different MoE router modules.
"""

name: MoERouterType = MoERouterType.default
"""
The name of the implementation.
"""
num_experts: int = 1
top_k: int = 1
jitter_eps: Optional[float] = None
normalize_expert_weights: Optional[float] = None
uniform_expert_assignment: bool = False
bias: bool = True
dtype: DType = DType.float32

def num_params(self, d_model: int) -> int:
"""
The number of params that the module will have once built.

:param d_model: The model dimensionality.
"""
num_params = 0
if self.name == MoERouterType.default:
num_params += d_model * self.num_experts
if self.bias:
num_params += self.num_experts
else:
raise NotImplementedError

return num_params
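
    # Worked example (hypothetical numbers, not from this PR): with d_model=1024,
    # num_experts=8, and bias=True, the default router has 1024 * 8 + 8 = 8200 params.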

def build(self, d_model: int, *, init_device: str = "cpu") -> "MoERouter":
"""
Build the corresponding MoE router module.

:param d_model: The model dimensionality.
        :param init_device: The device to initialize the parameters on, e.g. "cpu", "meta".
"""
kwargs = self.as_dict(exclude_none=True, recurse=False)
kwargs.pop("name")
kwargs.update(
dtype=kwargs.pop("dtype").as_pt(),
d_model=d_model,
init_device=init_device,
)

try:
if self.name == MoERouterType.default:
return MoELinearRouter(**kwargs)
else:
raise NotImplementedError(self.name)
except TypeError as e:
raise OLMoConfigurationError(
f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
) from e


class MoERouter(nn.Module):
"""
A base class for MoE router modules.

:param d_model: The model dimensionality (hidden size).
:param num_experts: The total number of experts.
:param top_k: The number of experts to assign to each token.
:param jitter_eps: Controls the amount of noise added to the input during training.
:param normalize_expert_weights: The type of norm (e.g. ``2.0`` for L2 norm) to use to normalize
the expert weights.
:param uniform_expert_assignment: Force uniform assignment. Useful for benchmarking.
"""

def __init__(
self,
*,
d_model: int,
num_experts: int,
top_k: int = 1,
jitter_eps: Optional[float] = None,
normalize_expert_weights: Optional[float] = None,
uniform_expert_assignment: bool = False,
):
super().__init__()
self.d_model = d_model
self.num_experts = num_experts
self.top_k = top_k
self.jitter_eps = jitter_eps
self.normalize_expert_weights = normalize_expert_weights
self.uniform_expert_assignment = uniform_expert_assignment

def jitter(self, x: torch.Tensor) -> torch.Tensor:
if self.jitter_eps is None or not self.training:
return x
else:
low = 1.0 - self.jitter_eps
high = 1.0 + self.jitter_eps
noise = torch.rand_like(x)
return x * (low + noise * (high - low))

def get_top_k(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.top_k == 1:
return scores.max(dim=-1, keepdim=True)
return torch.topk(scores, self.top_k, dim=-1)

@abstractmethod
def get_expert_scores(self, x: torch.Tensor) -> torch.Tensor:
"""
Given the input ``x`` of shape ``(*, d_model)``, compute the expert scores.

:returns: The expert scores, shape ``(*, num_experts)``.
"""
raise NotImplementedError

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
        Given the input ``x`` of shape ``(batch_size, seq_len, d_model)``, compute the
        expert assignment.

        :returns: The scores of shape ``(batch_size * seq_len, num_experts)``, the expert
            weights of shape ``(batch_size * seq_len, top_k)``, and the expert indices of
            shape ``(batch_size * seq_len, top_k)``.
"""
# shape: (batch_size, seq_len, d_model)
x = self.jitter(x)

# shape: (batch_size * seq_len, num_experts)
scores = self.get_expert_scores(x.view(-1, self.d_model))

# shape: (batch_size * seq_len, top_k)
expert_weights, expert_indices = self.get_top_k(scores)

if self.normalize_expert_weights is not None:
expert_weights.div_(
torch.norm(
expert_weights,
p=self.normalize_expert_weights,
dim=-1,
keepdim=True,
)
)

if self.uniform_expert_assignment:
expert_indices = _uniform_expert_assignment(expert_indices, self.num_experts)

return scores, expert_weights, expert_indices


class MoELinearRouter(MoERouter):
"""
A simple, learned, linear router.
"""

def __init__(
self,
*,
bias: bool = True,
dtype: torch.dtype = torch.float32,
init_device: str = "cpu",
**kwargs,
):
super().__init__(**kwargs)
self.w_score = nn.Linear(
self.d_model, self.num_experts, bias=bias, dtype=dtype, device=init_device
)

def get_expert_scores(self, x: torch.Tensor) -> torch.Tensor:
logits = self.w_score(x.view(-1, self.d_model))
# TODO: save router logits for Z-loss
return logits.softmax(dim=-1)
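
A minimal, hypothetical sketch of building and calling the router through its config (the sizes are made up for illustration, and the import path is assumed from the file location):

import torch

from olmo_core.nn.moe.router import MoERouterConfig

# Build the default (linear) router for a small hypothetical model.
config = MoERouterConfig(num_experts=4, top_k=2)
router = config.build(d_model=16)

x = torch.randn(2, 8, 16)  # (batch_size, seq_len, d_model)
scores, expert_weights, expert_indices = router(x)

print(scores.shape)          # torch.Size([16, 4])
print(expert_weights.shape)  # torch.Size([16, 2])
print(expert_indices.shape)  # torch.Size([16, 2])

The router flattens the batch and sequence dimensions internally, so each returned tensor has ``batch_size * seq_len`` rows.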