Commit

add file; set default; rm branching
vchiley committed Dec 19, 2023
1 parent 9129059 commit 38cf337
Showing 4 changed files with 83 additions and 98 deletions.
24 changes: 24 additions & 0 deletions megablocks/layers/activation_fn.py
@@ -0,0 +1,24 @@
from typing import Callable

import torch
import stk


def activation_fn(x: stk.Matrix, function: Callable, return_grad_fn: bool = False, **kwargs):
    assert isinstance(x, stk.Matrix)
    with torch.set_grad_enabled(return_grad_fn):
        if return_grad_fn:
            x.data.requires_grad = True
        out = function(x.data, **kwargs)
        y = stk.Matrix(
            x.size(),
            out,
            x.row_indices,
            x.column_indices,
            x.offsets,
            x.column_indices_t,
            x.offsets_t,
            x.block_offsets_t)
        if return_grad_fn:
            return y, out.backward
        return y
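
For orientation (an editorial example, not part of this commit): the new helper applies an arbitrary elementwise function to the data of an stk.Matrix, rebuilds the sparse Matrix around the result, and, when return_grad_fn=True, also returns the deferred backward of that output so a memory-optimized caller can recompute gradients later. A minimal dense-tensor sketch of the same pattern in plain PyTorch:

    # Dense analog of activation_fn(..., return_grad_fn=True); illustrative only,
    # operating on a plain tensor instead of the .data of an stk.Matrix.
    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8)
    with torch.set_grad_enabled(True):
        x.requires_grad = True
        out = F.gelu(x, approximate="tanh")
    grad_fn = out.backward          # deferred backward, returned with the output

    # Later, e.g. in a recompute-based backward pass, feed the upstream gradient in:
    grad_fn(torch.ones_like(out))   # populates x.grad
    assert x.grad is not None
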
3 changes: 2 additions & 1 deletion megablocks/layers/arguments.py
@@ -3,6 +3,7 @@
import megablocks.turbo_util as turbo
import megablocks.grouped_gemm_util as grouped_gemm
import torch
import torch.nn.functional as F
from typing import Callable, Optional, Union

# Type annotation for in-place Tensor initialization function.
@@ -19,7 +20,7 @@ class Arguments:
num_layers : int = 1
bias : bool = True
return_bias : bool = True
activation_fn : Optional[Callable] = None
activation_fn : Optional[Callable] = partial(F.gelu, approximate="tanh")

# MoE arguments.
moe_num_experts : int = 1
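
As a usage sketch (editorial, not part of the diff; it assumes the remaining Arguments fields keep their dataclass defaults), the new default gives callers tanh-approximated GELU unless they pass a different elementwise callable:

    from functools import partial

    import torch.nn.functional as F

    from megablocks.layers.arguments import Arguments

    # Default: activation_fn is partial(F.gelu, approximate="tanh").
    args = Arguments()

    # Override with another elementwise activation, mirroring the partial-based default.
    leaky_args = Arguments(activation_fn=partial(F.leaky_relu, negative_slope=0.1))
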
17 changes: 4 additions & 13 deletions megablocks/layers/glu.py
@@ -1,13 +1,11 @@
from megablocks.layers import common
from megablocks.layers import gelu
from megablocks.layers.activation_fn import act_fn
from megablocks.layers.activation_fn import activation_fn
from megablocks.layers.mlp import SparseMLP, create_dmoe_expert_weights
from megablocks.layers import mpu
from megablocks.layers.arguments import Arguments, InitFn
from megablocks import grouped_gemm_util as gg
import stk
import torch
import torch.nn.functional as F


class SparseGLU(SparseMLP):
@@ -39,11 +37,8 @@ def forward(self, x, topo):
x1 = stk.ops.sdd(x, w1.t(), topo)
x2 = stk.ops.sdd(x, v1.t(), topo)

if self.args.activation_fn:
act_fn_out = act_fn(x1, self.args.activation_fn)
else:
act_fn_out = gelu.gelu(x1)
x1 = stk.ops.mul(act_fn_out, x2)
activation_fn_out = activation_fn(x1, self.args.activation_fn)
x1 = stk.ops.mul(activation_fn_out, x2)

return stk.ops.dsd(x1, w2)

@@ -61,9 +56,5 @@ def forward(self, x, tokens_per_expert):
# Compute the MLP.
x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
if self.args.activation_fn:
act_fn_out = self.args.activation_fn(x1)
else:
act_fn_out = F.gelu(x1, approximate="tanh")
x1 = act_fn_out * x2
x1 = self.args.activation_fn(x1) * x2
return gg.ops.gmm(x1, w2, batch_sizes)
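
For reference (editorial example, not part of the commit), both GLU forwards above compute activation_fn(x @ w1.t()) gated by (x @ v1.t()) before the down projection; a dense-tensor sketch with assumed shapes:

    import torch
    import torch.nn.functional as F

    def dense_glu(x, w1, v1, w2,
                  activation_fn=lambda t: F.gelu(t, approximate="tanh")):
        # x: [tokens, hidden]; w1, v1, w2: [ffn_hidden, hidden]
        x1 = activation_fn(x @ w1.t())   # up projection + activation
        x2 = x @ v1.t()                  # gate projection
        return (x1 * x2) @ w2            # gated product, then down projection

    x = torch.randn(16, 64)
    w1, v1, w2 = (torch.randn(256, 64) for _ in range(3))
    out = dense_glu(x, w1, v1, w2)       # -> [16, 64]
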
137 changes: 53 additions & 84 deletions megablocks/layers/mlp.py
@@ -1,6 +1,5 @@
from megablocks.layers import common
from megablocks.layers import gelu
from megablocks.layers.activation_fn import act_fn
from megablocks.layers.activation_fn import activation_fn
from megablocks.layers import mpu
from megablocks.layers import weight_parallel as wp
from megablocks.layers.arguments import Arguments, InitFn
@@ -123,10 +122,7 @@ def scale_grad(self, w):

def forward(self, x):
x = torch.bmm(x, self.scale_grad(self.w1))
if self.args.activation_fn:
x = self.args.activation_fn(x)
else:
x = F.gelu(x, approximate="tanh")
x = self.args.activation_fn(x)
return torch.bmm(x, self.scale_grad(self.w2))


@@ -181,28 +177,25 @@ def forward(ctx, x, w1, w2, topo, num_input_bits, num_remat_bits, activation_fn)
x_q, x_scales = turbo.quantize_signed(x, num_bits=num_input_bits)
input_save_args = (x_q, x_scales)

# Activation Function.
# Activation function.
if num_remat_bits == -1:
if activation_fn is not None:
act_fn_out = act_fn(sdd_out, activation_fn)
else:
act_fn_out = gelu.gelu(sdd_out)
activation_fn_out = activation_fn(sdd_out, activation_fn)
input_save_args += (sdd_out.data,)
else:
if activation_fn is not None:
raise NotImplementedError(f'`num_remat_bits` != -1 not implemented for custom {activation_fn=} ({num_remat_bits=}).')
# fused GELU into sdd_out buffer while quantizing input
hidden_q, hidden_scales, act_fn_out_data = turbo.quantize_signed(
hidden_q, hidden_scales, activation_fn_out_data = turbo.quantize_signed(
sdd_out.data, num_bits=num_remat_bits,
op=turbo.ElemwiseOps.GELU_FORWARD, x_forward=sdd_out.data)
act_fn_out = sdd_out
activation_fn_out = sdd_out
input_save_args += (hidden_q, hidden_scales)

# Layer 1: x @ w2.
dsd_out = stk.ops.dsd(act_fn_out, w2)
dsd_out = stk.ops.dsd(activation_fn_out, w2)

# NOTE: Save the input to the layer and the gelu input for
# gradient computation. We'll re-compute the gelu forward
# NOTE: Save the input to the layer and the activation_fn input for
# gradient computation. We'll re-compute the activation_fn forward
# pass in the backward pass to avoid materializing another
# intermediate.
ctx.shape = topo.shape
@@ -244,49 +237,43 @@ def backward(ctx, ddsd_out):
activation_fn = ctx.activation_fn
if ctx.num_remat_bits == -1:
sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
if activation_fn is not None:
act_fn_out, act_grad_fn = act_fn(sdd_out, activation_fn, return_grad_fn=True)
else:
act_fn_out = gelu.gelu(sdd_out)
activation_fn_out, act_grad_fn = activation_fn(sdd_out, activation_fn, return_grad_fn=True)
else:
if activation_fn is not None:
raise NotImplementedError(f'`num_remat_bits` != -1 not implemented for custom {activation_fn=} ({num_remat_bits=}).')
act_fn_out_tensor = turbo.dequantize_signed(
activation_fn_out_tensor = turbo.dequantize_signed(
hidden_q, hidden_scales, num_bits=ctx.num_remat_bits,
op=turbo.ElemwiseOps.GELU_FORWARD,
out_shape=ctx.sdd_out_shape, out_dtype=dtype)
act_fn_out = stk.Matrix(ctx.shape, act_fn_out_tensor, *topo_tensors)
activation_fn_out = stk.Matrix(ctx.shape, activation_fn_out_tensor, *topo_tensors)

# Compute dw2 with recomputed act_fn output.
dw2 = stk.ops.dsd(act_fn_out.t(), ddsd_out)
# Compute dw2 with recomputed activation_fn output.
dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)

# Compute dact_fn_out.
# Compute dactivation_fn_out.
#
# NOTE: We reuse the act_fn_out allocation.
dact_fn_out = act_fn_out
# NOTE: We reuse the activation_fn_out allocation.
dactivation_fn_out = activation_fn_out
stk.backend.triton_kernels.sdd(
ddsd_out, w2.t(),
dact_fn_out.shape,
dact_fn_out.data,
dact_fn_out.offsets,
dact_fn_out.row_indices,
dact_fn_out.column_indices)
dactivation_fn_out.shape,
dactivation_fn_out.data,
dactivation_fn_out.offsets,
dactivation_fn_out.row_indices,
dactivation_fn_out.column_indices)

# Compute dsdd_out.
#
# NOTE: This reuses the dact_fn_out allocation.
# NOTE: This reuses the dactivation_fn_out allocation.
if ctx.num_remat_bits == -1:
if activation_fn is not None:
act_grad_fn(dact_fn_out.data)
dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
else:
dsdd_out = gelu.gelu_backward_(dact_fn_out, sdd_out)
act_grad_fn(dactivation_fn_out.data)
dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
else:
# confusingly, x_out is interpreted as the gradient to overwrite
# in-place when the elemwise op is a backwards op
ddsd_out_tensor = turbo.dequantize_signed(
hidden_q, hidden_scales, num_bits=ctx.num_remat_bits,
op=turbo.ElemwiseOps.GELU_BACKWARD, x_out=dact_fn_out.data)
op=turbo.ElemwiseOps.GELU_BACKWARD, x_out=dactivation_fn_out.data)
dsdd_out = stk.Matrix(ctx.shape, ddsd_out_tensor, *topo_tensors)

# rematerialize MLP input now that we need it
@@ -381,11 +368,8 @@ def parallel_forward(self, x, topo):

# Compute the MLP.
x = wp.sdd_nt(x, w1, topo, group)
if self.args.activation_fn is not None:
act_fn_out = act_fn(x, self.args.activation_fn)
else:
act_fn_out = gelu.gelu(x)
return wp.dsd_nn(act_fn_out, w2, group)
activation_fn_out = activation_fn(x, self.args.activation_fn)
return wp.dsd_nn(activation_fn_out, w2, group)

def forward(self, x, topo):
w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
@@ -398,11 +382,8 @@ def forward(self, x, topo):

# Compute the MLP.
x = stk.ops.sdd(x, w1.t(), topo)
if self.args.activation_fn is not None:
act_fn_out = act_fn(x, self.args.activation_fn)
else:
act_fn_out = gelu.gelu(x)
return stk.ops.dsd(act_fn_out, w2)
activation_fn_out = activation_fn(x, self.args.activation_fn)
return stk.ops.dsd(activation_fn_out, w2)


class MemoryOptimizedGroupedMLP(torch.autograd.Function):
@@ -427,26 +408,23 @@ def forward(ctx, x, w1, w2, batch_sizes, num_input_bits, num_remat_bits, activat

# Activation function.
if num_remat_bits == -1:
if activation_fn is not None:
act_fn_out = activation_fn(sdd_out)
else:
act_fn_out = F.gelu(sdd_out, approximate="tanh")
activation_fn_out = activation_fn(sdd_out)
input_save_args += (sdd_out,)
else:
if activation_fn is not None:
raise NotImplementedError(f'`num_remat_bits` != -1 not implemented for custom {activation_fn=} ({num_remat_bits=}).')
# Fused GELU into sdd_out buffer while quantizing input
hidden_q, hidden_scales, act_fn_out_data = turbo.quantize_signed(
hidden_q, hidden_scales, activation_fn_out_data = turbo.quantize_signed(
sdd_out, num_bits=num_remat_bits,
op=turbo.ElemwiseOps.GELU_FORWARD, x_forward=sdd_out)
act_fn_out = sdd_out
activation_fn_out = sdd_out
input_save_args += (hidden_q, hidden_scales)

# Layer 1: x @ w2.
dsd_out = gg.backend.gmm(act_fn_out, w2, batch_sizes)
dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)

# NOTE: Save the input to the layer and the act_fn input for
# gradient computation. We'll re-compute the act_fn forward
# NOTE: Save the input to the layer and the activation_fn input for
# gradient computation. We'll re-compute the activation_fn forward
# pass in the backward pass to avoid materializing another
# intermediate.
ctx.num_input_bits = num_input_bits
@@ -485,51 +463,45 @@ def backward(ctx, ddsd_out):
else:
hidden_q, hidden_scales = saved_tensors[-2:]

# Rematerialize act_fn output.
# Rematerialize activation_fn output.
activation_fn = ctx.activation_fn
if ctx.num_remat_bits == -1:
act_grad_fn = None
if activation_fn is not None:
with torch.set_grad_enabled(True):
sdd_out.requires_grad = True
act_fn_out = activation_fn(sdd_out)
act_grad_fn = act_fn_out.backward
else:
act_fn_out = F.gelu(sdd_out, approximate="tanh")
with torch.set_grad_enabled(True):
sdd_out.requires_grad = True
activation_fn_out = activation_fn(sdd_out)
act_grad_fn = activation_fn_out.backward
else:
if activation_fn is not None:
raise NotImplementedError(f'`num_remat_bits` != -1 not implemented for custom {activation_fn=} ({num_remat_bits=}).')
act_fn_out = turbo.dequantize_signed(
activation_fn_out = turbo.dequantize_signed(
hidden_q, hidden_scales, num_bits=ctx.num_remat_bits,
op=turbo.ElemwiseOps.GELU_FORWARD,
out_shape=ctx.sdd_out_shape, out_dtype=dtype)

# Compute dw2 with recomputed act_fn output.
# Compute dw2 with recomputed activation_fn output.
dw2 = gg.backend.gmm(
act_fn_out, ddsd_out, batch_sizes, trans_a=True)
activation_fn_out, ddsd_out, batch_sizes, trans_a=True)

# Compute dact_fn_out.
# Compute dactivation_fn_out.
#
# NOTE: We reuse the act_fn_out allocation.
dact_fn_out = act_fn_out
# NOTE: We reuse the activation_fn_out allocation.
dactivation_fn_out = activation_fn_out
gg.backend.gmm(
ddsd_out, w2, batch_sizes, trans_b=True, c=dact_fn_out)
ddsd_out, w2, batch_sizes, trans_b=True, c=dactivation_fn_out)

# Compute dsdd_out.
#
# NOTE: This reuses the dact_fn_out allocation.
# NOTE: This reuses the dactivation_fn_out allocation.
if ctx.num_remat_bits == -1:
if act_grad_fn is not None:
act_grad_fn(dact_fn_out)
dsdd_out = sdd_out.grad
else:
dsdd_out = gelu.gelu_backward_(dact_fn_out, sdd_out)
act_grad_fn(dactivation_fn_out)
dsdd_out = sdd_out.grad
else:
# confusingly, x_out is interpreted as the gradient to overwrite
# in-place when the elemwise op is a backwards op
dsdd_out = turbo.dequantize_signed(
hidden_q, hidden_scales, num_bits=ctx.num_remat_bits,
op=turbo.ElemwiseOps.GELU_BACKWARD, x_out=dact_fn_out.data)
op=turbo.ElemwiseOps.GELU_BACKWARD, x_out=dactivation_fn_out.data)

# rematerialize MLP input now that we need it
if ctx.num_input_bits != -1:
@@ -574,8 +546,5 @@ def forward(self, x, tokens_per_expert):

# Compute the MLP.
x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
if self.args.activation_fn is not None:
x = self.args.activation_fn(x)
else:
x = F.gelu(x, approximate="tanh")
x = self.args.activation_fn(x)
return gg.ops.gmm(x, w2, batch_sizes)
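
The MemoryOptimizedMLP and MemoryOptimizedGroupedMLP changes above save the activation input rather than its output and re-run the activation inside backward. A stripped-down sketch of that recompute pattern with a configurable activation_fn (editorial example; it omits sparsity, quantization, weight scaling, and expert grouping):

    import torch
    import torch.nn.functional as F

    class RecomputeActivationMLP(torch.autograd.Function):
        """y = activation_fn(x @ w1) @ w2, re-running activation_fn in backward."""

        @staticmethod
        def forward(ctx, x, w1, w2, activation_fn):
            hidden = x @ w1
            out = activation_fn(hidden) @ w2
            # Save the activation input, not its output, to cut peak memory.
            ctx.save_for_backward(x, w1, w2, hidden)
            ctx.activation_fn = activation_fn
            return out

        @staticmethod
        def backward(ctx, grad_out):
            x, w1, w2, hidden = ctx.saved_tensors
            with torch.enable_grad():
                hidden = hidden.detach().requires_grad_(True)
                act_out = ctx.activation_fn(hidden)   # recompute the forward activation
            dw2 = act_out.t() @ grad_out
            dact = grad_out @ w2.t()
            act_out.backward(dact)                    # gradient through the activation
            dhidden = hidden.grad
            dw1 = x.t() @ dhidden
            dx = dhidden @ w1.t()
            return dx, dw1, dw2, None

    # Usage sketch:
    x = torch.randn(8, 32, requires_grad=True)
    w1 = torch.randn(32, 128, requires_grad=True)
    w2 = torch.randn(128, 32, requires_grad=True)
    y = RecomputeActivationMLP.apply(x, w1, w2, F.gelu)
    y.sum().backward()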
