Full transformer #11

Merged · 4 commits · May 4, 2023
31 changes: 28 additions & 3 deletions examples/scale_analysis.py
@@ -1,7 +1,7 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import torch

from unit_scaling.modules import MHSA, MLP, Linear, TransformerLayer
from unit_scaling.modules import MHSA, MLP, Linear, TransformerDecoder, TransformerLayer
from unit_scaling.utils import analyse_module

print("=== Unit-scaled Linear ===\n")
@@ -31,10 +31,11 @@
seq_len = 2**6
hidden_size = 2**6
heads = 4
dropout_p = 0.1
input = torch.randn(batch_size, seq_len, hidden_size).requires_grad_()
backward = torch.randn(batch_size, seq_len, hidden_size)

annotated_code = analyse_module(MHSA(hidden_size, heads), input, backward)
annotated_code = analyse_module(MHSA(hidden_size, heads, dropout_p), input, backward)
print(annotated_code)

print("=== Unit-scaled Transformer Layer ===\n")
@@ -43,8 +44,32 @@
seq_len = 2**6
hidden_size = 2**6
heads = 4
dropout_p = 0.1
input = torch.randn(batch_size, seq_len, hidden_size).requires_grad_()
backward = torch.randn(batch_size, seq_len, hidden_size)

annotated_code = analyse_module(TransformerLayer(hidden_size, heads), input, backward)
annotated_code = analyse_module(
TransformerLayer(hidden_size, heads, dropout_p), input, backward
)
print(annotated_code)

print("=== Unit-scaled Full Transformer Decoder ===\n")

batch_size = 2**8
seq_len = 2**6
hidden_size = 2**6
vocab_size = 2**12
layers = 2
heads = 4
dropout_p = 0.1

seq = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len + 1))
input_idxs = seq[:, :-1]
labels = torch.roll(seq, -1, 1)[:, 1:]
backward = torch.randn(batch_size, seq_len, hidden_size)

annotated_code = analyse_module(
TransformerDecoder(hidden_size, vocab_size, layers, heads, dropout_p),
(input_idxs, labels),
)
print(annotated_code)
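For context, a rough follow-on sketch of running the same unit-scaled decoder directly (the forward signature `(input_idxs, labels) -> scalar loss` is assumed here from the `analyse_module` call above; it is not confirmed by this diff):

```python
# Hypothetical sketch: run the unit-scaled decoder outside analyse_module.
# The forward signature (input_idxs, labels) -> scalar loss is an assumption.
import torch

from unit_scaling.modules import TransformerDecoder

hidden_size, vocab_size, layers, heads, dropout_p = 2**6, 2**12, 2, 4, 0.1
batch_size, seq_len = 2**8, 2**6

seq = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len + 1))
input_idxs = seq[:, :-1]
labels = torch.roll(seq, -1, 1)[:, 1:]

model = TransformerDecoder(hidden_size, vocab_size, layers, heads, dropout_p)
loss = model(input_idxs, labels)  # assumed to return the scalar training loss
loss.backward()
```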
6 changes: 4 additions & 2 deletions unit_scaling/constraints.py
@@ -3,10 +3,12 @@
"""Common scale-constraints used in unit-scaled operations."""

from math import pow, prod
from typing import Callable
from typing import Callable, Tuple, Union

BinaryConstraint = Callable[[float, float], float]
TernaryConstraint = Callable[[float, float, float], float]
TernaryConstraint = Callable[
[float, float, float], Union[float, Tuple[float, float, float]]
]
VariadicConstraint = Callable[..., float]


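To illustrate the widened `TernaryConstraint` signature, a hypothetical constraint (not part of this PR) could tie the two gradient scales together while leaving the output scale free, returning all three scales as a tuple:

```python
# Hypothetical constraint matching the new TernaryConstraint return type:
# Union[float, Tuple[float, float, float]]. It ties the two grad scales to
# their geometric mean and leaves the output scale untouched.
from typing import Tuple


def tie_grad_scales(
    output_scale: float, left_grad_scale: float, right_grad_scale: float
) -> Tuple[float, float, float]:
    grad_scale = (left_grad_scale * right_grad_scale) ** 0.5
    return output_scale, grad_scale, grad_scale
```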
107 changes: 88 additions & 19 deletions unit_scaling/docs.py
@@ -1,7 +1,9 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from functools import partial
from typing import Any, Callable, List, Optional, Type, TypeVar
import inspect
from functools import wraps
from itertools import zip_longest
from typing import Any, Callable, Iterable, List, Optional, Type, TypeVar

from docstring_parser.google import (
DEFAULT_SECTIONS,
@@ -14,11 +16,54 @@
T = TypeVar("T")


def _validate(
f: Callable[..., T], unsupported_args: Iterable[str] = {}
) -> Callable[..., T]:
"""Wraps the supplied function in a check to ensure its arguments aren't in the
unsupported args list. Unsupported args are by nature optional (i.e. they have
a default value). It is assumed this default is valid, but all other values are
invalid."""

argspec = inspect.getfullargspec(f)

# argspec.defaults is a tuple of default arguments. These may begin at an offset
# relative to argspec.args due to args without a default. To zip these properly the
# lists are reversed, zipped, and un-reversed, with missing values filled with `...`
rev_args = reversed(argspec.args)
rev_defaults = reversed(argspec.defaults) if argspec.defaults else []
rev_arg_default_pairs = list(zip_longest(rev_args, rev_defaults, fillvalue=...))
default_kwargs = dict(reversed(rev_arg_default_pairs))

for arg in unsupported_args:
if arg not in default_kwargs:
print(default_kwargs, argspec)
Collaborator: Is this print deliberately retained?

Contributor Author: Oops!

raise ValueError(f"unsupported arg '{arg}' is not valid.")
if default_kwargs[arg] is ...:
raise ValueError(f"unsupported arg '{arg}' has no default value")

@wraps(f)
def f_new(*args: Any, **kwargs: Any) -> T:
arg_values = dict(zip(argspec.args, args))
full_kwargs = {**arg_values, **kwargs}
for arg_name, arg_value in full_kwargs.items():
arg_default_value = default_kwargs[arg_name]
if arg_name in unsupported_args and arg_value != arg_default_value:
raise ValueError(
f"Support for the '{arg_name}' argument has not been implemented"
" for the unit-scaling library. Please remove it or replace it"
" with its default value."
)
return f(*args, **kwargs)

return f_new
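As a standalone illustration of the reversed `zip_longest` pairing above (the `demo` function here is hypothetical), arguments without a default end up mapped to the `...` placeholder:

```python
import inspect
from itertools import zip_longest


def demo(a, b, c=3, d=4):  # hypothetical function, only for illustration
    pass


argspec = inspect.getfullargspec(demo)
rev_args = reversed(argspec.args)
rev_defaults = reversed(argspec.defaults) if argspec.defaults else []
pairs = list(zip_longest(rev_args, rev_defaults, fillvalue=...))
default_kwargs = dict(reversed(pairs))
print(default_kwargs)  # {'a': Ellipsis, 'b': Ellipsis, 'c': 3, 'd': 4}
```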


def _get_docstring_from_target(
source: T,
target: Any,
short_description: Optional[str] = None,
add_args: Optional[List[str]] = None,
unsupported_args: Iterable[str] = {},
) -> T:
"""Takes the docstring from `target`, modifies it, and applies it to `source`."""

@@ -30,18 +75,27 @@ def _get_docstring_from_target(
parser = GoogleParser(sections=parser_sections)
docstring = parser.parse(target.__doc__)
docstring.short_description = short_description

for param in docstring.params:
if param.arg_name in unsupported_args and param.description is not None:
param.description = (
"**[not supported by unit-scaling]** " + param.description
)

if add_args:
for arg_str in add_args:
# Parse the additional args strings and add them to the docstring object
param = parser._build_meta(arg_str, "Args")
docstring.meta.append(param)
param_meta = parser._build_meta(arg_str, "Args")
docstring.meta.append(param_meta)

source.__doc__ = compose(docstring) # docstring object to actual string
return source


def inherit_docstring(
short_description: Optional[str] = None,
add_args: Optional[List[str]] = None,
unsupported_args: Iterable[str] = {},
) -> Callable[[Type[T]], Type[T]]:
"""Returns a decorator which causes the wrapped class to inherit its parent
docstring, with the specified modifications applied.
@@ -51,28 +105,33 @@ def inherit_docstring(
description in the parent docstring with the one supplied. Defaults to None.
add_args (Optional[List[str]], optional): appends the supplied argument strings
to the list of arguments. Defaults to None.
unsupported_args (Iterable[str]): a list of arguments which are not supported.
Documentation is updated and runtime checks added to enforce this.

Returns:
Callable[[Type], Type]: the decorator used to wrap the child class.
"""

def decorator(cls: Type[T]) -> Type[T]:
parent = cls.mro()[1]
return _get_docstring_from_target(
source = _get_docstring_from_target(
source=cls,
target=parent,
short_description=short_description,
add_args=add_args,
unsupported_args=unsupported_args,
)
return _validate(source, unsupported_args) # type: ignore

return decorator


def docstring_from(
target: Any,
target: Callable[..., T],
short_description: Optional[str] = None,
add_args: Optional[List[str]] = None,
) -> Callable[[T], T]:
unsupported_args: Iterable[str] = {},
) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""Returns a decorator which causes the wrapped object to take the docstring from
the target object, with the specified modifications applied.

@@ -82,19 +141,27 @@ def docstring_from(
description in the parent docstring with the one supplied. Defaults to None.
add_args (Optional[List[str]], optional): appends the supplied argument strings
to the list of arguments. Defaults to None.
unsupported_args (Iterable[str]): a list of arguments which are not supported.
Documentation is updated and runtime checks added to enforce this.

Returns:
Callable[[Callable], Callable]: the decorator used to wrap the child object.
"""
return partial(
_get_docstring_from_target,
target=target,
short_description=short_description,
add_args=add_args,
)

def decorator(source: Callable[..., T]) -> Callable[..., T]:
source = _get_docstring_from_target(
source=source,
target=target,
short_description=short_description,
add_args=add_args,
unsupported_args=unsupported_args,
)
return _validate(source, unsupported_args)

return decorator
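A usage sketch of the updated decorator, mirroring the `dropout` wrapper added in `unit_scaling/functional.py` below (the `my_dropout` name is illustrative): the wrapped function inherits `F.dropout`'s docstring, marks `inplace` as unsupported, and rejects non-default values for it at call time.

```python
import torch
import torch.nn.functional as F
from torch import Tensor

from unit_scaling.docs import docstring_from


@docstring_from(
    F.dropout,
    short_description="Applies a **unit-scaled** dropout function.",
    unsupported_args=["inplace"],
)
def my_dropout(
    input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False
) -> Tensor:
    return F.dropout(input, p, training, inplace)


x = torch.randn(8, 8)
my_dropout(x, p=0.1)  # fine: the unsupported arg is left at its default
# my_dropout(x, p=0.1, inplace=True)  # raises ValueError via the _validate wrapper
```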


def format_docstring(*args: str) -> Callable[[T], T]:
def format_docstring(*args: str) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""Returns a decorator that applies `cls.__doc__.format(*args)` to the target class.

Args:
@@ -120,11 +187,13 @@ def f(cls: T) -> T:
)

ternary_constraint_docstring = (
"constraint (Optional[Callable[[float, float, float], float]], optional): function"
" which takes `output_scale`, `left_grad_scale` & `right_grad_scale` (in that"
" order) and returns a single 'constrained' scale (usually necessary for valid"
" gradients). If `None` is provided, no constraint will be applied. Defaults to"
" `gmean`."
"constraint (Optional[Callable[[float, float, float], Union[float, Tuple[float,"
" float, float]]]], optional): function which takes `output_scale`,"
"`left_grad_scale` & `right_grad_scale` (in that order) and returns either"
" a) a single 'constrained' scale (often necessary for valid gradients), or b) a"
"tuple of three output scales in the same order of the input scales (it is expected"
"that two or all of these scales are constrained to be the same). If `None` is"
"provided, no constraint will be applied. Defaults to `gmean`."
)

variadic_constraint_docstring = (
83 changes: 77 additions & 6 deletions unit_scaling/functional.py
@@ -72,11 +72,11 @@ def gelu(
)
def softmax(
input: Tensor,
dim: Optional[int] = None,
dim: int,
dtype: Optional[torch.dtype] = None,
constraint: Optional[BinaryConstraint] = gmean,
) -> Tensor:
dim_size = input.shape[dim] if dim is not None else input.numel()
dim_size = input.shape[dim]
# Scale factors determined empirically, assuming unit-scaled & large dim_size
output_scale = dim_size / 1.31
grad_input_scale = dim_size / 1.65
@@ -87,7 +87,9 @@ def softmax(


@docstring_from(
F.dropout, short_description="Applies a **unit-scaled** dropout function."
F.dropout,
short_description="Applies a **unit-scaled** dropout function.",
unsupported_args=["inplace"],
)
def dropout(
input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False
@@ -118,9 +120,11 @@ def matmul(
right_grad_scale = left_size**-0.5

if constraint:
output_scale = left_grad_scale = right_grad_scale = constraint(
output_scale, left_grad_scale, right_grad_scale
)
scale = constraint(output_scale, left_grad_scale, right_grad_scale)
if isinstance(scale, Sequence):
output_scale, left_grad_scale, right_grad_scale = scale # type: ignore
else:
output_scale = left_grad_scale = right_grad_scale = scale
Collaborator: Perhaps this block should go to constraints as apply_ternary or something?

Contributor Author: I agree what I have is a bit ugly, but I'm also a bit concerned that another level of indirection might be hard for new users to follow. Might leave this as-is for now...


left = scale_bwd(left, left_grad_scale)
right = scale_bwd(right, right_grad_scale)
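A sketch of the `apply_ternary` helper floated in the review above (not adopted in this PR): it would move the `Sequence` check into `unit_scaling.constraints`, so callers like `matmul` only ever see a three-tuple of scales.

```python
# Hypothetical helper, not adopted in this PR: normalise a constraint's output
# to an (output_scale, left_grad_scale, right_grad_scale) tuple.
from collections.abc import Sequence
from typing import Optional, Tuple

from unit_scaling.constraints import TernaryConstraint


def apply_ternary(
    constraint: Optional[TernaryConstraint],
    output_scale: float,
    left_grad_scale: float,
    right_grad_scale: float,
) -> Tuple[float, float, float]:
    if constraint is None:
        return output_scale, left_grad_scale, right_grad_scale
    scale = constraint(output_scale, left_grad_scale, right_grad_scale)
    if isinstance(scale, Sequence):
        output_scale, left_grad_scale, right_grad_scale = scale
        return output_scale, left_grad_scale, right_grad_scale
    return scale, scale, scale
```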
@@ -219,3 +223,70 @@ def residual_add(residual: Tensor, skip: Tensor, tau: float = 0.2) -> Tensor:
residual = scale_fwd(residual, tau**0.5)
skip = scale_fwd(skip, (1 - tau) ** 0.5)
return residual + skip


@docstring_from(
F.embedding,
short_description=(
"A **unit-scaled** lookup table that looks up embeddings in a fixed dictionary"
"and size."
),
unsupported_args=["scale_grad_by_freq", "sparse"],
)
def embedding(
input: Tensor,
weight: Tensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
) -> Tensor:
batch_size = prod(input.shape)
weight = scale_bwd(weight, (weight.shape[0] / batch_size) ** 0.5)
Collaborator: I think this is the right rule, but it does sometimes feel a bit risky! Perhaps in the case where it's risky (e.g. knowledge graph vocab_size=1M), the user really needs to set sparse=True and we should also do something else.

Contributor Author: Yeah that's a good point, I'd forgotten about this issue. I'm tempted to say that for now we shouldn't support sparse=True and add that to our todo list for some point down the line. For huge vocab or tiny batch we may have an issue.

Having said that, even for 2**20 vocab and 2**8 batch the scaling factor is only 64, which isn't too bad. And in the sparse setting, if you don't have that then maybe you just get dominated by the non-sparse decoder grads in the long run, unless you have this slightly crazy scaling for the encoder grads?

Collaborator: Sorry, feels like a long time to reply... 👍 sounds reasonable.

return F.embedding(
input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse
)
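Putting numbers on the backward scale discussed in the review above: with `weight.shape[0] = 2**20` and `2**8` tokens in the batch, the factor works out to 64.

```python
# Worked figures from the review thread: (vocab_size / batch_size) ** 0.5.
vocab_size, num_tokens = 2**20, 2**8
print((vocab_size / num_tokens) ** 0.5)  # 64.0 -- the "not too bad" factor above
```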


@docstring_from(
F.cross_entropy,
short_description=(
"Computes a **unit-scaled** the cross entropy loss between input logits and"
" target."
),
unsupported_args=["weight", "size_average", "reduce", "label_smoothing"],
)
def cross_entropy(
input: Tensor,
target: Tensor,
weight: Optional[Tensor] = None,
size_average: Optional[bool] = None,
ignore_index: int = -100,
reduce: Optional[bool] = None,
reduction: str = "mean",
label_smoothing: float = 0.0,
) -> Tensor:
if len(input.shape) == 2:
batch_size, vocab_size = input.shape
elif len(input.shape) == 1:
batch_size, vocab_size = 1, input.shape[0]
else:
assert False, (
f"cross_entropy input shape is {input.shape}, but should be either"
" (vocab_size,) or (batch_size, vocab_size)"
)
input = scale_bwd(input, vocab_size / (vocab_size - 1) ** 0.5)
loss = F.cross_entropy(
input,
target,
weight,
size_average,
ignore_index,
reduce,
reduction="sum",
label_smoothing=label_smoothing,
)
if reduction == "mean":
return scale_fwd(loss, 1 / batch_size)
return loss
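A minimal usage sketch (shapes assumed): the default `reduction="mean"` path computes the summed loss and then scales it forward by `1 / batch_size`, as above.

```python
import torch

import unit_scaling.functional as U

batch_size, vocab_size = 2**8, 2**12
logits = torch.randn(batch_size, vocab_size, requires_grad=True)
targets = torch.randint(0, vocab_size, (batch_size,))

loss = U.cross_entropy(logits, targets)  # unit-scaled, reduction="mean"
loss.backward()
```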