From 9ea2b939bcf2a851044b3ca8bb7ccb959efa9bf5 Mon Sep 17 00:00:00 2001
From: Charlie Blake
Date: Thu, 27 Apr 2023 13:45:04 +0000
Subject: [PATCH 1/3] Add embedding and cross entropy loss

---
 unit_scaling/functional.py            | 64 +++++++++++++++++++++++++++
 unit_scaling/modules.py               | 37 ++++++++++++++++
 unit_scaling/tests/test_functional.py | 36 ++++++++++++++-
 unit_scaling/tests/test_modules.py    | 27 +++++++++++
 4 files changed, 163 insertions(+), 1 deletion(-)

diff --git a/unit_scaling/functional.py b/unit_scaling/functional.py
index 5f44e21..fffe939 100644
--- a/unit_scaling/functional.py
+++ b/unit_scaling/functional.py
@@ -219,3 +219,67 @@ def residual_add(residual: Tensor, skip: Tensor, tau: float = 0.2) -> Tensor:
     residual = scale_fwd(residual, tau**0.5)
     skip = scale_fwd(skip, (1 - tau) ** 0.5)
     return residual + skip
+
+
+@docstring_from(
+    F.embedding,
+    short_description=(
+        "A **unit-scaled** lookup table that looks up embeddings in a fixed dictionary"
+        " and size."
+    ),
+)
+def embedding(
+    input: Tensor,
+    weight: Tensor,
+    padding_idx: Optional[int] = None,
+    max_norm: Optional[float] = None,
+    norm_type: float = 2.0,
+    scale_grad_by_freq: bool = False,
+    sparse: bool = False,
+) -> Tensor:
+    batch_size = prod(input.shape)
+    weight = scale_bwd(weight, (weight.shape[0] / batch_size) ** 0.5)
+    return F.embedding(
+        input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse
+    )
+
+
+@docstring_from(
+    F.cross_entropy,
+    short_description=(
+        "Computes a **unit-scaled** cross entropy loss between input logits and"
+        " target."
+    ),
+)
+def cross_entropy(
+    input: Tensor,
+    target: Tensor,
+    weight: Optional[Tensor] = None,
+    size_average: Optional[bool] = None,
+    ignore_index: int = -100,
+    reduce: Optional[bool] = None,
+    reduction: str = "mean",
+    label_smoothing: float = 0.0,
+) -> Tensor:
+    if len(input.shape) == 2:
+        batch_size, vocab_size = input.shape
+    else:
+        assert (
+            len(input.shape) == 1
+        ), "Input must be (vocab_size) or (batch_size, vocab_size)"
+        batch_size, vocab_size = 1, input.shape[0]
+    input = scale_bwd(input, vocab_size / (vocab_size - 1) ** 0.5)
+    loss = F.cross_entropy(
+        input,
+        target,
+        weight,
+        size_average,
+        ignore_index,
+        reduce,
+        reduction="sum",
+        label_smoothing=label_smoothing,
+    )
+    if reduction == "mean":
+        return scale_fwd(loss, 1 / batch_size)
+    assert reduction == "sum", "cross_entropy reduction must be either 'sum' or 'mean'."
+    return loss
diff --git a/unit_scaling/modules.py b/unit_scaling/modules.py
index 0ba8986..5eae4af 100644
--- a/unit_scaling/modules.py
+++ b/unit_scaling/modules.py
@@ -110,6 +110,43 @@ def forward(self, input: Tensor) -> Tensor:
         )
 
 
+@inherit_docstring(
+    short_description=(
+        "A **unit-scaled** lookup table that looks up embeddings in a fixed dictionary"
+        " and size."
+    )
+)
+class Embedding(nn.Embedding):
+    def forward(self, input: Tensor) -> Tensor:
+        return U.embedding(
+            input,
+            self.weight,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+        )
+
+
+@inherit_docstring(
+    short_description=(
+        "Computes a **unit-scaled** cross entropy loss between input logits and"
+        " target."
+ ) +) +class CrossEntropyLoss(nn.CrossEntropyLoss): + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return U.cross_entropy( + input, + target, + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + label_smoothing=self.label_smoothing, + ) + + @format_docstring(binary_constraint_docstring) class MLP(nn.Module): """A **unit-scaled** implementation of an MLP layer. diff --git a/unit_scaling/tests/test_functional.py b/unit_scaling/tests/test_functional.py index 3426e7d..550b248 100644 --- a/unit_scaling/tests/test_functional.py +++ b/unit_scaling/tests/test_functional.py @@ -2,7 +2,8 @@ import pytest -from torch import Tensor, zeros +import torch.nn.functional as F +from torch import Tensor, randint, zeros from ..constraints import ( gmean, @@ -12,7 +13,9 @@ to_right_grad_scale, ) from ..functional import ( + cross_entropy, dropout, + embedding, gelu, layer_norm, linear, @@ -279,3 +282,34 @@ def test_residual() -> None: unit_backward(output) assert_unit_scaled(residual, output, residual.grad, skip.grad, input.grad) + + +# --- test embedding() --- + + +def test_embedding() -> None: + batch_sz, seq_len, embedding_dim, num_embeddings = 2**4, 2**5, 2**6, 2**12 + input_idxs = randint(low=0, high=2**12, size=(batch_sz, seq_len)) + embedding_table = unit_normal(num_embeddings, embedding_dim) + output = embedding(input_idxs, embedding_table) + unit_backward(output) + + assert_unit_scaled(output, embedding_table.grad) + + +# --- test cross_entropy() --- + + +def test_cross_entropy() -> None: + num_tokens, vocab_sz = 2**12, 2**8 + for reduction in ["mean", "sum"]: + for input_shape in [(vocab_sz,), (num_tokens, vocab_sz)]: + input = unit_normal(*input_shape) + label_size = (input_shape[0],) if len(input_shape) == 2 else () + labels = randint(low=0, high=vocab_sz, size=label_size) + loss = cross_entropy(input, labels, reduction=reduction) + standard_loss = F.cross_entropy(input, labels, reduction=reduction) + loss.backward() # type: ignore [no-untyped-call] + + assert loss == standard_loss + assert_unit_scaled(input.grad) diff --git a/unit_scaling/tests/test_modules.py b/unit_scaling/tests/test_modules.py index 72f2748..521fdfe 100644 --- a/unit_scaling/tests/test_modules.py +++ b/unit_scaling/tests/test_modules.py @@ -2,13 +2,16 @@ import pytest import torch +from torch import randint from torch.optim import SGD from ..modules import ( GELU, MHSA, MLP, + CrossEntropyLoss, Dropout, + Embedding, LayerNorm, Linear, Softmax, @@ -89,6 +92,30 @@ def test_layer_norm() -> None: assert_unit_scaled(output, input.grad, model.weight.grad, model.bias.grad) +def test_embedding() -> None: + batch_sz, seq_len, embedding_dim, num_embeddings = 2**4, 2**5, 2**6, 2**12 + input_idxs = randint(low=0, high=2**12, size=(batch_sz, seq_len)) + model = Embedding(num_embeddings, embedding_dim) + output = model(input_idxs) + + assert output.shape == torch.Size([batch_sz, seq_len, embedding_dim]) + + unit_backward(output) + + assert_unit_scaled(model.weight.grad) + + +def test_cross_entropy_loss() -> None: + num_tokens, vocab_sz = 2**12, 2**8 + input = unit_normal(num_tokens, vocab_sz) + labels = randint(low=0, high=vocab_sz, size=(num_tokens,)) + model = CrossEntropyLoss() + loss = model(input, labels) + loss.backward() + + assert_unit_scaled(input.grad) + + def test_mlp() -> None: input = unit_normal(2**8, 2**10) model = MLP(2**10) From 8c566449049b1d02b530f40a555a4305aa7b561d Mon Sep 17 00:00:00 2001 From: Charlie Blake Date: Thu, 27 Apr 2023 15:51:12 +0000 Subject: [PATCH 2/3] Add 
full transformer decoder layer --- examples/scale_analysis.py | 20 +++++++++++- unit_scaling/functional.py | 2 +- unit_scaling/modules.py | 49 ++++++++++++++++++++++++++++++ unit_scaling/tests/test_modules.py | 46 +++++++++++++++++++++++----- unit_scaling/utils.py | 28 +++++++++-------- 5 files changed, 123 insertions(+), 22 deletions(-) diff --git a/examples/scale_analysis.py b/examples/scale_analysis.py index 94eec08..3e88bb3 100644 --- a/examples/scale_analysis.py +++ b/examples/scale_analysis.py @@ -1,7 +1,7 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. import torch -from unit_scaling.modules import MHSA, MLP, Linear, TransformerLayer +from unit_scaling.modules import MHSA, MLP, Linear, TransformerDecoder, TransformerLayer from unit_scaling.utils import analyse_module print("=== Unit-scaled Linear ===\n") @@ -48,3 +48,21 @@ annotated_code = analyse_module(TransformerLayer(hidden_size, heads), input, backward) print(annotated_code) + +print("=== Unit-scaled Full Transformer Decoder ===\n") + +batch_size = 2**8 +seq_len = 2**6 +hidden_size = 2**6 +vocab_size = 2**12 +layers = 2 +heads = 4 + +input_idxs = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) +labels = torch.roll(input_idxs, -1, 1) # Note: doesn't handle final index correctly +backward = torch.randn(batch_size, seq_len, hidden_size) + +annotated_code = analyse_module( + TransformerDecoder(hidden_size, vocab_size, layers, heads), (input_idxs, labels) +) +print(annotated_code) diff --git a/unit_scaling/functional.py b/unit_scaling/functional.py index fffe939..686bd27 100644 --- a/unit_scaling/functional.py +++ b/unit_scaling/functional.py @@ -266,7 +266,7 @@ def cross_entropy( else: assert ( len(input.shape) == 1 - ), "Input must be (vocab_size) or (batch_size, vocab_size)" + ), "cross_entropy input shape must be (vocab_size,) or (batch_size, vocab_size)" batch_size, vocab_size = 1, input.shape[0] input = scale_bwd(input, vocab_size / (vocab_size - 1) ** 0.5) loss = F.cross_entropy( diff --git a/unit_scaling/modules.py b/unit_scaling/modules.py index 5eae4af..e0f2174 100644 --- a/unit_scaling/modules.py +++ b/unit_scaling/modules.py @@ -263,3 +263,52 @@ def forward(self, input: Tensor) -> Tensor: input = self.mlp(input) input = U.dropout(input, self.dropout_p, self.training) return U.residual_add(input, skip, self.tau) + + +class TransformerDecoder(nn.Module): + """A **unit-scaled** implementation of a decoder-type transformer. + + Note: this class is currently just for demonstrating scaling and lacks key + functionality (e.g. causal masking, positional embeddings, usage for inference). + + Args: + hidden_size (int): _description_ + vocab_size (int): _description_ + layers (int): _description_ + heads (int): _description_ + dropout_p (float, optional): _description_. Defaults to 0.1. + act_fn (nn.Module, optional): _description_. Defaults to GELU(). + tau (float, optional): _description_. Defaults to 0.2. + constraint (Optional[VariadicConstraint], optional): _description_. Defaults to + gmean. 
+ """ + + def __init__( + self, + hidden_size: int, + vocab_size: int, + layers: int, + heads: int, + dropout_p: float = 0.1, + act_fn: nn.Module = GELU(), + tau: float = 0.2, + constraint: Optional[VariadicConstraint] = gmean, + ) -> None: + super().__init__() + self.embedding = Embedding(vocab_size, hidden_size) + self.dropout_p = dropout_p + self.transformer_layers = nn.Sequential( + *( + TransformerLayer(hidden_size, heads, dropout_p, act_fn, tau, constraint) + for _ in range(layers) + ) + ) + self.final_layer_norm = LayerNorm(hidden_size) + + def forward(self, input_ids: Tensor, labels: Tensor) -> Tensor: + input = self.embedding(input_ids) + input = U.dropout(input, self.dropout_p, self.training) + input = self.transformer_layers(input) + input = self.final_layer_norm(input) + input = U.linear(input, self.embedding.weight, bias=None, constraint=None) + return U.cross_entropy(input.flatten(end_dim=-2), labels.flatten()) diff --git a/unit_scaling/tests/test_modules.py b/unit_scaling/tests/test_modules.py index 521fdfe..e22a403 100644 --- a/unit_scaling/tests/test_modules.py +++ b/unit_scaling/tests/test_modules.py @@ -15,6 +15,7 @@ LayerNorm, Linear, Softmax, + TransformerDecoder, TransformerLayer, ) from .helper import ( @@ -136,13 +137,13 @@ def test_mlp() -> None: def test_mhsa() -> None: - b, s, d = 2**8, 2**6, 2**6 - input = unit_normal(b, s, d) - model = MHSA(d, heads=8) + batch_sz, seq_len, hidden_dim = 2**8, 2**6, 2**6 + input = unit_normal(batch_sz, seq_len, hidden_dim) + model = MHSA(hidden_dim, heads=8) output = model(input) assert_unit_scaled(model.linear_qkv.weight, model.linear_o.weight) - assert output.shape == torch.Size([b, s, d]) + assert output.shape == torch.Size([batch_sz, seq_len, hidden_dim]) unit_backward(output) SGD(model.parameters(), lr=1).step() @@ -154,15 +155,44 @@ def test_mhsa() -> None: def test_transformer_layer() -> None: - b, s, d = 2**8, 2**6, 2**6 - input = unit_normal(b, s, d) - model = TransformerLayer(d, heads=8) + batch_sz, seq_len, hidden_dim, heads = 2**8, 2**6, 2**6, 8 + input = unit_normal(batch_sz, seq_len, hidden_dim) + model = TransformerLayer(hidden_dim, heads=heads) output = model(input) - assert output.shape == torch.Size([b, s, d]) + assert output.shape == torch.Size([batch_sz, seq_len, hidden_dim]) unit_backward(output) SGD(model.parameters(), lr=1).step() combined_std = output.std().detach() * input.grad.std() # type: ignore assert combined_std == pytest.approx(1, abs=0.1) + + +def test_transformer_decoder() -> None: + batch_size = 2**8 + seq_len = 2**6 + hidden_size = 2**6 + vocab_size = 2**12 + layers = 2 + heads = 4 + + input_idxs = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) + labels = torch.roll(input_idxs, -1, 1) + model = TransformerDecoder(hidden_size, vocab_size, layers, heads) + loss = model(input_idxs, labels) + + assert loss.shape == torch.Size([]) + + loss.backward() + SGD(model.parameters(), lr=1).step() + + for name, p in model.named_parameters(): + if "layer_norm.weight" in name: + threshold = 5.0 + elif "layer_norm.bias" in name: + threshold = 20.0 + else: + threshold = 2.5 + assert p.grad is not None + assert p.grad.std().detach() == pytest.approx(1, rel=threshold), name diff --git a/unit_scaling/utils.py b/unit_scaling/utils.py index 3270e76..cbe8f0a 100644 --- a/unit_scaling/utils.py +++ b/unit_scaling/utils.py @@ -74,7 +74,7 @@ def __init__(self, module: fx.GraphModule): def run_node(self, n: fx.Node) -> Any: out = super().run_node(n) - if isinstance(out, Tensor): + if isinstance(out, 
Tensor) and out.is_floating_point(): self.scales[n.name] = ScalePair() out = ScaleTracker.apply(out, self.scales[n.name]) return out @@ -100,8 +100,8 @@ def placeholder( def _record_scales( fx_graph_module: fx.GraphModule, - input: Tensor, - backward: Tensor, + inputs: Tuple[Tensor, ...], + backward: Optional[Tensor] = None, ) -> ScaleDict: """Given a `torch.fx.GraphModule`, and dummy tensors to feed into the forward and backward passes, returns a dictionary of the scales (standard deviations) of every @@ -109,15 +109,15 @@ def _record_scales( Args: fx_graph_module (fx.GraphModule): the module to record. - input (Tensor): fed into the forward pass for analysis. - backward (Tensor): fed into the output's `.backward()` method for - analysis. + input (Tuple[Tensor, ...]): fed into the forward pass for analysis. + backward (Tensor, optional): fed into the output's `.backward()` method for + analysis. Defaults to `None`, equivalent to calling plain `.backward()`. Returns: ScaleDict: An ordered dictionary with `ScalePair`s for each intermediate tensor. """ tracking_module = ScaleTrackingInterpreter(fx_graph_module) - out = tracking_module.run(input) + out = tracking_module.run(*inputs) out.backward(backward) return tracking_module.scales @@ -229,8 +229,8 @@ def trace( def analyse_module( module: nn.Module, - input: Tensor, - backward: Tensor, + inputs: Union[Tensor, Tuple[Tensor, ...]], + backward: Optional[Tensor] = None, recurse_modules: bool = True, syntax_highlight: bool = True, autowrap_modules: Tuple[ModuleType, ...] = (math, einops, functional), @@ -242,8 +242,10 @@ def analyse_module( Args: module (nn.Module): the module to analyse. - input (Tensor): fed into the forward pass for analysis. - backward (Tensor): fed into the output's `.backward()` method for analysis. + inputs (Union[Tensor, Tuple[Tensor, ...]]): fed into the forward pass for + analysis. + backward (Tensor, optional): fed into the output's `.backward()` method for + analysis. Defaults to `None`, equivalent to calling plain `.backward()`. recurse_modules (bool, optional): toggles recursive behavour. Defaults to True. syntax_highlight (bool, optional): Defaults to True. 
autowrap_modules (Tuple[ModuleType]): defaults to @@ -296,5 +298,7 @@ def forward(self, x): (-> 1.0, <- 0.236) fx_graph = tracer.trace(module) fx_graph_module = fx.GraphModule(tracer.root, fx_graph) - scales = _record_scales(fx_graph_module, input, backward) + if not isinstance(inputs, tuple): + inputs = (inputs,) + scales = _record_scales(fx_graph_module, inputs, backward) return _annotate(fx_graph_module.code, scales, syntax_highlight=syntax_highlight) From 85074c8e87620ae12909a071a9a9b7e40d46d843 Mon Sep 17 00:00:00 2001 From: Charlie Blake Date: Mon, 1 May 2023 22:18:18 +0000 Subject: [PATCH 3/3] Add unsupported arg option to docs wrapper --- examples/scale_analysis.py | 17 ++-- unit_scaling/constraints.py | 6 +- unit_scaling/docs.py | 107 +++++++++++++++++++++----- unit_scaling/functional.py | 29 ++++--- unit_scaling/modules.py | 57 ++++++++------ unit_scaling/tests/test_docs.py | 73 ++++++++++++++++++ unit_scaling/tests/test_functional.py | 41 +++++++++- unit_scaling/tests/test_modules.py | 29 ++++--- unit_scaling/tests/test_utils.py | 4 +- 9 files changed, 289 insertions(+), 74 deletions(-) create mode 100644 unit_scaling/tests/test_docs.py diff --git a/examples/scale_analysis.py b/examples/scale_analysis.py index 3e88bb3..0798274 100644 --- a/examples/scale_analysis.py +++ b/examples/scale_analysis.py @@ -31,10 +31,11 @@ seq_len = 2**6 hidden_size = 2**6 heads = 4 +dropout_p = 0.1 input = torch.randn(batch_size, seq_len, hidden_size).requires_grad_() backward = torch.randn(batch_size, seq_len, hidden_size) -annotated_code = analyse_module(MHSA(hidden_size, heads), input, backward) +annotated_code = analyse_module(MHSA(hidden_size, heads, dropout_p), input, backward) print(annotated_code) print("=== Unit-scaled Transformer Layer ===\n") @@ -43,10 +44,13 @@ seq_len = 2**6 hidden_size = 2**6 heads = 4 +dropout_p = 0.1 input = torch.randn(batch_size, seq_len, hidden_size).requires_grad_() backward = torch.randn(batch_size, seq_len, hidden_size) -annotated_code = analyse_module(TransformerLayer(hidden_size, heads), input, backward) +annotated_code = analyse_module( + TransformerLayer(hidden_size, heads, dropout_p), input, backward +) print(annotated_code) print("=== Unit-scaled Full Transformer Decoder ===\n") @@ -57,12 +61,15 @@ vocab_size = 2**12 layers = 2 heads = 4 +dropout_p = 0.1 -input_idxs = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) -labels = torch.roll(input_idxs, -1, 1) # Note: doesn't handle final index correctly +seq = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len + 1)) +input_idxs = seq[:, :-1] +labels = torch.roll(seq, -1, 1)[:, 1:] backward = torch.randn(batch_size, seq_len, hidden_size) annotated_code = analyse_module( - TransformerDecoder(hidden_size, vocab_size, layers, heads), (input_idxs, labels) + TransformerDecoder(hidden_size, vocab_size, layers, heads, dropout_p), + (input_idxs, labels), ) print(annotated_code) diff --git a/unit_scaling/constraints.py b/unit_scaling/constraints.py index 005613e..82ec197 100644 --- a/unit_scaling/constraints.py +++ b/unit_scaling/constraints.py @@ -3,10 +3,12 @@ """Common scale-constraints used in unit-scaled operations.""" from math import pow, prod -from typing import Callable +from typing import Callable, Tuple, Union BinaryConstraint = Callable[[float, float], float] -TernaryConstraint = Callable[[float, float, float], float] +TernaryConstraint = Callable[ + [float, float, float], Union[float, Tuple[float, float, float]] +] VariadicConstraint = Callable[..., float] diff --git 
a/unit_scaling/docs.py b/unit_scaling/docs.py
index 516a0ed..2e9ea7c 100644
--- a/unit_scaling/docs.py
+++ b/unit_scaling/docs.py
@@ -1,7 +1,9 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
 
-from functools import partial
-from typing import Any, Callable, List, Optional, Type, TypeVar
+import inspect
+from functools import wraps
+from itertools import zip_longest
+from typing import Any, Callable, Iterable, List, Optional, Type, TypeVar
 
 from docstring_parser.google import (
     DEFAULT_SECTIONS,
@@ -14,11 +16,54 @@
 T = TypeVar("T")
 
 
+def _validate(
+    f: Callable[..., T], unsupported_args: Iterable[str] = {}
+) -> Callable[..., T]:
+    """Wraps the supplied function in a check to ensure its arguments aren't in the
+    unsupported args list. Unsupported args are by nature optional (i.e. they have
+    a default value). It is assumed this default is valid, but all other values are
+    invalid."""
+
+    argspec = inspect.getfullargspec(f)
+
+    # argspec.defaults is a tuple of default arguments. These may begin at an offset
+    # relative to argspec.args due to args without a default. To zip these properly the
+    # lists are reversed, zipped, and un-reversed, with missing values filled with `...`
+    rev_args = reversed(argspec.args)
+    rev_defaults = reversed(argspec.defaults) if argspec.defaults else []
+    rev_arg_default_pairs = list(zip_longest(rev_args, rev_defaults, fillvalue=...))
+    default_kwargs = dict(reversed(rev_arg_default_pairs))
+
+    for arg in unsupported_args:
+        if arg not in default_kwargs:
+            raise ValueError(f"unsupported arg '{arg}' is not valid.")
+        if default_kwargs[arg] is ...:
+            raise ValueError(f"unsupported arg '{arg}' has no default value")
+
+    @wraps(f)
+    def f_new(*args: Any, **kwargs: Any) -> T:
+        arg_values = dict(zip(argspec.args, args))
+        full_kwargs = {**arg_values, **kwargs}
+        for arg_name, arg_value in full_kwargs.items():
+            arg_default_value = default_kwargs[arg_name]
+            if arg_name in unsupported_args and arg_value != arg_default_value:
+                raise ValueError(
+                    f"Support for the '{arg_name}' argument has not been implemented"
+                    " for the unit-scaling library. Please remove it or replace it"
+                    " with its default value."
+ ) + return f(*args, **kwargs) + + return f_new + + def _get_docstring_from_target( source: T, target: Any, short_description: Optional[str] = None, add_args: Optional[List[str]] = None, + unsupported_args: Iterable[str] = {}, ) -> T: """Takes the docstring from `target`, modifies it, and applies it to `source`.""" @@ -30,11 +75,19 @@ def _get_docstring_from_target( parser = GoogleParser(sections=parser_sections) docstring = parser.parse(target.__doc__) docstring.short_description = short_description + + for param in docstring.params: + if param.arg_name in unsupported_args and param.description is not None: + param.description = ( + "**[not supported by unit-scaling]** " + param.description + ) + if add_args: for arg_str in add_args: # Parse the additional args strings and add them to the docstring object - param = parser._build_meta(arg_str, "Args") - docstring.meta.append(param) + param_meta = parser._build_meta(arg_str, "Args") + docstring.meta.append(param_meta) + source.__doc__ = compose(docstring) # docstring object to actual string return source @@ -42,6 +95,7 @@ def _get_docstring_from_target( def inherit_docstring( short_description: Optional[str] = None, add_args: Optional[List[str]] = None, + unsupported_args: Iterable[str] = {}, ) -> Callable[[Type[T]], Type[T]]: """Returns a decorator which causes the wrapped class to inherit its parent docstring, with the specified modifications applied. @@ -51,6 +105,8 @@ def inherit_docstring( description in the parent docstring with the one supplied. Defaults to None. add_args (Optional[List[str]], optional): appends the supplied argument strings to the list of arguments. Defaults to None. + unsupported_args (Iterable[str]): a list of arguments which are not supported. + Documentation is updated and runtime checks added to enforce this. Returns: Callable[[Type], Type]: the decorator used to wrap the child class. @@ -58,21 +114,24 @@ def inherit_docstring( def decorator(cls: Type[T]) -> Type[T]: parent = cls.mro()[1] - return _get_docstring_from_target( + source = _get_docstring_from_target( source=cls, target=parent, short_description=short_description, add_args=add_args, + unsupported_args=unsupported_args, ) + return _validate(source, unsupported_args) # type: ignore return decorator def docstring_from( - target: Any, + target: Callable[..., T], short_description: Optional[str] = None, add_args: Optional[List[str]] = None, -) -> Callable[[T], T]: + unsupported_args: Iterable[str] = {}, +) -> Callable[[Callable[..., T]], Callable[..., T]]: """Returns a decorator which causes the wrapped object to take the docstring from the target object, with the specified modifications applied. @@ -82,19 +141,27 @@ def docstring_from( description in the parent docstring with the one supplied. Defaults to None. add_args (Optional[List[str]], optional): appends the supplied argument strings to the list of arguments. Defaults to None. + unsupported_args (Iterable[str]): a list of arguments which are not supported. + Documentation is updated and runtime checks added to enforce this. Returns: Callable[[Callable], Callable]: the decorator used to wrap the child object. 
""" - return partial( - _get_docstring_from_target, - target=target, - short_description=short_description, - add_args=add_args, - ) + + def decorator(source: Callable[..., T]) -> Callable[..., T]: + source = _get_docstring_from_target( + source=source, + target=target, + short_description=short_description, + add_args=add_args, + unsupported_args=unsupported_args, + ) + return _validate(source, unsupported_args) + + return decorator -def format_docstring(*args: str) -> Callable[[T], T]: +def format_docstring(*args: str) -> Callable[[Callable[..., T]], Callable[..., T]]: """Returns a decorator that applies `cls.__doc__.format(*args)` to the target class. Args: @@ -120,11 +187,13 @@ def f(cls: T) -> T: ) ternary_constraint_docstring = ( - "constraint (Optional[Callable[[float, float, float], float]], optional): function" - " which takes `output_scale`, `left_grad_scale` & `right_grad_scale` (in that" - " order) and returns a single 'constrained' scale (usually necessary for valid" - " gradients). If `None` is provided, no constraint will be applied. Defaults to" - " `gmean`." + "constraint (Optional[Callable[[float, float, float], Union[float, Tuple[float," + " float, float]]]], optional): function which takes `output_scale`," + "`left_grad_scale` & `right_grad_scale` (in that order) and returns either" + " a) a single 'constrained' scale (often necessary for valid gradients), or b) a" + "tuple of three output scales in the same order of the input scales (it is expected" + "that two or all of these scales are constrained to be the same). If `None` is" + "provided, no constraint will be applied. Defaults to `gmean`." ) variadic_constraint_docstring = ( diff --git a/unit_scaling/functional.py b/unit_scaling/functional.py index 686bd27..d0c420b 100644 --- a/unit_scaling/functional.py +++ b/unit_scaling/functional.py @@ -72,11 +72,11 @@ def gelu( ) def softmax( input: Tensor, - dim: Optional[int] = None, + dim: int, dtype: Optional[torch.dtype] = None, constraint: Optional[BinaryConstraint] = gmean, ) -> Tensor: - dim_size = input.shape[dim] if dim is not None else input.numel() + dim_size = input.shape[dim] # Scale factors determined empirically, assuming unit-scaled & large dim_size output_scale = dim_size / 1.31 grad_input_scale = dim_size / 1.65 @@ -87,7 +87,9 @@ def softmax( @docstring_from( - F.dropout, short_description="Applies a **unit-scaled** dropout function." + F.dropout, + short_description="Applies a **unit-scaled** dropout function.", + unsupported_args=["inplace"], ) def dropout( input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False @@ -118,9 +120,11 @@ def matmul( right_grad_scale = left_size**-0.5 if constraint: - output_scale = left_grad_scale = right_grad_scale = constraint( - output_scale, left_grad_scale, right_grad_scale - ) + scale = constraint(output_scale, left_grad_scale, right_grad_scale) + if isinstance(scale, Sequence): + output_scale, left_grad_scale, right_grad_scale = scale # type: ignore + else: + output_scale = left_grad_scale = right_grad_scale = scale left = scale_bwd(left, left_grad_scale) right = scale_bwd(right, right_grad_scale) @@ -227,6 +231,7 @@ def residual_add(residual: Tensor, skip: Tensor, tau: float = 0.2) -> Tensor: "A **unit-scaled** lookup table that looks up embeddings in a fixed dictionary" "and size." ), + unsupported_args=["scale_grad_by_freq", "sparse"], ) def embedding( input: Tensor, @@ -250,6 +255,7 @@ def embedding( "Computes a **unit-scaled** the cross entropy loss between input logits and" " target." 
), + unsupported_args=["weight", "size_average", "reduce", "label_smoothing"], ) def cross_entropy( input: Tensor, @@ -263,11 +269,13 @@ def cross_entropy( ) -> Tensor: if len(input.shape) == 2: batch_size, vocab_size = input.shape - else: - assert ( - len(input.shape) == 1 - ), "cross_entropy input shape must be (vocab_size,) or (batch_size, vocab_size)" + elif len(input.shape) == 1: batch_size, vocab_size = 1, input.shape[0] + else: + assert False, ( + f"cross_entropy input shape is {input.shape}, but should be either" + " (vocab_size,) or (batch_size, vocab_size)" + ) input = scale_bwd(input, vocab_size / (vocab_size - 1) ** 0.5) loss = F.cross_entropy( input, @@ -281,5 +289,4 @@ def cross_entropy( ) if reduction == "mean": return scale_fwd(loss, 1 / batch_size) - assert reduction == "sum", "cross_entropy reduction must be either 'sum' or 'mean'." return loss diff --git a/unit_scaling/modules.py b/unit_scaling/modules.py index e0f2174..50e2eae 100644 --- a/unit_scaling/modules.py +++ b/unit_scaling/modules.py @@ -31,7 +31,7 @@ def __init__( self.constraint = constraint def forward(self, input: Tensor) -> Tensor: - return U.gelu(input, self.constraint) + return U.gelu(input, self.constraint) # type: ignore @inherit_docstring( @@ -51,7 +51,7 @@ def forward(self, input: Tensor) -> Tensor: class Softmax(nn.Softmax): def __init__( self, - dim: Optional[int] = None, + dim: int, constraint: Optional[BinaryConstraint] = gmean, ) -> None: super().__init__(dim=dim) @@ -61,7 +61,10 @@ def forward(self, input: Tensor) -> Tensor: return U.softmax(input, dim=self.dim, constraint=self.constraint) -@inherit_docstring(short_description="A **unit-scaled** implementation of Dropout.") +@inherit_docstring( + short_description="A **unit-scaled** implementation of Dropout.", + unsupported_args=["inplace"], +) class Dropout(nn.Dropout): def __init__(self, p: float = 0.5, inplace: bool = False) -> None: super().__init__(p, inplace) @@ -114,7 +117,8 @@ def forward(self, input: Tensor) -> Tensor: short_description=( "A **unit-scaled** lookup table that looks up embeddings in a fixed dictionary" " and size." - ) + ), + unsupported_args=["scale_grad_by_freq", "sparse"], ) class Embedding(nn.Embedding): def forward(self, input: Tensor) -> Tensor: @@ -133,7 +137,8 @@ def forward(self, input: Tensor) -> Tensor: short_description=( "Computes a **unit-scaled** the cross entropy loss between input logits and" " target." - ) + ), + unsupported_args=["weight", "size_average", "reduce", "label_smoothing"], ) class CrossEntropyLoss(nn.CrossEntropyLoss): def forward(self, input: Tensor, target: Tensor) -> Tensor: @@ -182,11 +187,12 @@ def forward(self, input: Tensor) -> Tensor: class MHSA(nn.Module): """A **unit-scaled** implementation of a multi-head self-attention layer. + Warning: using `constraint=None` here will likely give incorrect gradients. + Args: hidden_size (int): the hidden dimension size of the input. heads (int): the number of attention heads. dropout_p (float, optional): the probability of the post-softmax dropout. - Defaults to 0.1. 
{0} """ @@ -194,7 +200,7 @@ def __init__( self, hidden_size: int, heads: int, - dropout_p: float = 0.1, + dropout_p: float, constraint: Optional[VariadicConstraint] = gmean, ) -> None: super().__init__() @@ -209,8 +215,8 @@ def __init__( self.constraint = constraint def forward(self, input: Tensor) -> Tensor: - qkv = self.linear_qkv(input) - q, k, v = einops.rearrange(qkv, "b s (d z h) -> z b h s d", h=self.heads, z=3) + q_k_v = self.linear_qkv(input) + q, k, v = einops.rearrange(q_k_v, "b s (z h d) -> z b h s d", h=self.heads, z=3) qk = U.matmul(q, k.transpose(-1, -2), constraint=self.constraint) qk = U.softmax(qk, dim=-1, constraint=self.constraint) qk = U.dropout(qk, self.dropout_p, training=self.training) @@ -219,15 +225,17 @@ def forward(self, input: Tensor) -> Tensor: return self.linear_o(qkv) # type: ignore +@format_docstring(variadic_constraint_docstring) class TransformerLayer(nn.Module): """A **unit-scaled** implementation of a PreNorm (see https://arxiv.org/abs/2002.04745) transformer layer. + Warning: using `constraint=None` here will likely give incorrect gradients. + Args: hidden_size (int): the hidden dimension size of the input. heads (int): the number of attention heads. dropout_p (float, optional): the probability of the post-softmax dropout. - Defaults to 0.1. act_fn (nn.Module): the activation function module. Defaults to `GELU()`. tau (float, optional): the weighting of the residual branch relative to the skip connection. Defaults to 0.2. @@ -238,7 +246,7 @@ def __init__( self, hidden_size: int, heads: int, - dropout_p: float = 0.1, + dropout_p: float, act_fn: nn.Module = GELU(), tau: float = 0.2, constraint: Optional[VariadicConstraint] = gmean, @@ -265,22 +273,25 @@ def forward(self, input: Tensor) -> Tensor: return U.residual_add(input, skip, self.tau) +@format_docstring(variadic_constraint_docstring) class TransformerDecoder(nn.Module): """A **unit-scaled** implementation of a decoder-type transformer. Note: this class is currently just for demonstrating scaling and lacks key functionality (e.g. causal masking, positional embeddings, usage for inference). + Warning: using `constraint=None` here will likely give incorrect gradients. + Args: - hidden_size (int): _description_ - vocab_size (int): _description_ - layers (int): _description_ - heads (int): _description_ - dropout_p (float, optional): _description_. Defaults to 0.1. - act_fn (nn.Module, optional): _description_. Defaults to GELU(). - tau (float, optional): _description_. Defaults to 0.2. - constraint (Optional[VariadicConstraint], optional): _description_. Defaults to - gmean. + hidden_size (int): the hidden dimension size of the input. + vocab_size (int): the number of tokens in the vocabulary. + layers (int): the number of transformer layers. + heads (int): the number of attention heads. + dropout_p (float, optional): the probability of the post-softmax dropout. + act_fn (nn.Module): the activation function module. Defaults to `GELU()`. + tau (float, optional): the weighting of the residual branch relative to the skip + connection. Defaults to 0.2. 
+ {0} """ def __init__( @@ -289,14 +300,14 @@ def __init__( vocab_size: int, layers: int, heads: int, - dropout_p: float = 0.1, + dropout_p: float, act_fn: nn.Module = GELU(), tau: float = 0.2, constraint: Optional[VariadicConstraint] = gmean, ) -> None: super().__init__() self.embedding = Embedding(vocab_size, hidden_size) - self.dropout_p = dropout_p + self.initial_layer_norm = LayerNorm(hidden_size) self.transformer_layers = nn.Sequential( *( TransformerLayer(hidden_size, heads, dropout_p, act_fn, tau, constraint) @@ -307,7 +318,7 @@ def __init__( def forward(self, input_ids: Tensor, labels: Tensor) -> Tensor: input = self.embedding(input_ids) - input = U.dropout(input, self.dropout_p, self.training) + input = self.initial_layer_norm(input) input = self.transformer_layers(input) input = self.final_layer_norm(input) input = U.linear(input, self.embedding.weight, bias=None, constraint=None) diff --git a/unit_scaling/tests/test_docs.py b/unit_scaling/tests/test_docs.py new file mode 100644 index 0000000..dcbecab --- /dev/null +++ b/unit_scaling/tests/test_docs.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. + +import pytest + +from ..docs import _validate + + +def f(a, b: int, c="3", d: float = 4.0) -> str: # type: ignore + return f"{a} {b} {c} {d}" + + +def test_validate_no_args() -> None: + def g() -> int: + return 0 + + valid_g = _validate(g) + assert valid_g() == 0 + + +def test_validate_positional_args() -> None: + # Works with no unsupported args + valid_f = _validate(f) + assert valid_f(None, 2) == "None 2 3 4.0" + assert valid_f(None, 2, "3", 4.5) == "None 2 3 4.5" + + # Works with some unsupported args + valid_f = _validate(f, unsupported_args=["c", "d"]) + + # Works if unsupported args are not present or equal default + assert valid_f(None, 2) == "None 2 3 4.0" + assert valid_f(None, 2, "3", 4.0) == "None 2 3 4.0" + + # Doesn't work if non-default unsupported args provided + with pytest.raises(ValueError) as e: + valid_f(None, 2, "3.4") + assert "argument has not been implemented" in str(e.value) + with pytest.raises(ValueError) as e: + valid_f(None, 2, "3", 4.5) + assert "argument has not been implemented" in str(e.value) + + +def test_validate_positional_kwargs() -> None: + # Works with no unsupported args + valid_f = _validate(f) + assert valid_f(None, 2) == "None 2 3 4.0" + assert valid_f(None, 2, c="3", d=4.5) == "None 2 3 4.5" + + # Works with some unsupported args + valid_f = _validate(f, unsupported_args=["c", "d"]) + + # Works if unsupported args are not present or equal default + assert valid_f(None, 2) == "None 2 3 4.0" + assert valid_f(None, 2, c="3") + + # Doesn't work if non-default unsupported args provided + with pytest.raises(ValueError) as e: + valid_f(None, 2, c="3.4") + assert "argument has not been implemented" in str(e.value) + with pytest.raises(ValueError) as e: + valid_f(None, 2, d=4.5) + assert "argument has not been implemented" in str(e.value) + + +def test_validate_invalid_arg() -> None: + with pytest.raises(ValueError) as e: + _validate(f, unsupported_args=["z"]) + assert "is not valid" in str(e.value) + + +def test_validate_no_default_arg() -> None: + with pytest.raises(ValueError) as e: + _validate(f, unsupported_args=["b"]) + assert "has no default value" in str(e.value) diff --git a/unit_scaling/tests/test_functional.py b/unit_scaling/tests/test_functional.py index 550b248..da401f3 100644 --- a/unit_scaling/tests/test_functional.py +++ b/unit_scaling/tests/test_functional.py @@ -1,5 +1,6 @@ # Copyright (c) 2023 
Graphcore Ltd. All rights reserved. +from typing import Tuple import pytest import torch.nn.functional as F @@ -118,7 +119,7 @@ def test_gelu_scale_for_grad_input() -> None: def test_softmax_no_constraint() -> None: input = unit_normal(2**12) - output = softmax(input, constraint=None) + output = softmax(input, dim=0, constraint=None) unit_backward(output) assert_unit_scaled(output, input.grad) @@ -126,7 +127,7 @@ def test_softmax_no_constraint() -> None: def test_softmax_scale_for_output() -> None: input = unit_normal(2**12) - output = softmax(input, constraint=to_output_scale) + output = softmax(input, dim=0, constraint=to_output_scale) unit_backward(output) assert_unit_scaled(output) @@ -135,7 +136,7 @@ def test_softmax_scale_for_output() -> None: def test_softmax_scale_for_grad_input() -> None: input = unit_normal(2**12) - output = softmax(input, constraint=to_grad_input_scale) + output = softmax(input, dim=0, constraint=to_grad_input_scale) unit_backward(output) assert_unit_scaled(input.grad) @@ -164,6 +165,9 @@ def test_dropout() -> None: assert_unit_scaled(output, input.grad) + with pytest.raises(ValueError): + dropout(unit_normal(2**20), 0.5, inplace=True) + # --- test matmul() --- @@ -207,6 +211,25 @@ def test_matmul_scale_for_grad_right() -> None: assert_not_unit_scaled(output, left.grad) +def test_matmul_custom_constraint() -> None: + def constrain_grad_left( + output_scale: float, left_grad_scale: float, right_grad_scale: float + ) -> Tuple[float, float, float]: + output_scale = left_grad_scale = gmean(output_scale, left_grad_scale) + return output_scale, left_grad_scale, right_grad_scale + + left = unit_normal(2**8, 2**10) + right = unit_normal(2**10, 2**12) + output = matmul(left, right, constraint=constrain_grad_left) + unit_backward(output) + + assert_unit_scaled(right.grad) + assert_not_unit_scaled(output, left.grad) + + combined_out_left_std = output.std().detach() * left.grad.std() # type: ignore + assert combined_out_left_std == pytest.approx(1, abs=0.1) + + # --- test linear() --- @@ -296,6 +319,11 @@ def test_embedding() -> None: assert_unit_scaled(output, embedding_table.grad) + with pytest.raises(ValueError): + embedding(input_idxs, embedding_table, scale_grad_by_freq=True) + with pytest.raises(ValueError): + embedding(input_idxs, embedding_table, sparse=True) + # --- test cross_entropy() --- @@ -313,3 +341,10 @@ def test_cross_entropy() -> None: assert loss == standard_loss assert_unit_scaled(input.grad) + + input = unit_normal(2**12, 2**8) + labels = randint(low=0, high=vocab_sz, size=(num_tokens,)) + with pytest.raises(ValueError): + cross_entropy(input, labels, weight=unit_normal(vocab_sz)) + with pytest.raises(ValueError): + cross_entropy(input, labels, label_smoothing=0.5) diff --git a/unit_scaling/tests/test_modules.py b/unit_scaling/tests/test_modules.py index e22a403..0c46fd0 100644 --- a/unit_scaling/tests/test_modules.py +++ b/unit_scaling/tests/test_modules.py @@ -41,7 +41,7 @@ def test_gelu() -> None: def test_softmax() -> None: input = unit_normal(2**14) - model = Softmax() + model = Softmax(dim=0) output = model(input) unit_backward(output) @@ -60,6 +60,9 @@ def test_dropout() -> None: combined_std = output.std().detach() * input.grad.std() # type: ignore assert combined_std == pytest.approx(1, abs=0.1) + with pytest.raises(ValueError): + Dropout(0.5, inplace=True) + def test_linear() -> None: input = unit_normal(2**8, 2**10) @@ -105,6 +108,11 @@ def test_embedding() -> None: assert_unit_scaled(model.weight.grad) + with pytest.raises(ValueError): + 
Embedding(num_embeddings, embedding_dim, scale_grad_by_freq=True) + with pytest.raises(ValueError): + Embedding(num_embeddings, embedding_dim, sparse=True) + def test_cross_entropy_loss() -> None: num_tokens, vocab_sz = 2**12, 2**8 @@ -116,6 +124,11 @@ def test_cross_entropy_loss() -> None: assert_unit_scaled(input.grad) + with pytest.raises(ValueError): + CrossEntropyLoss(weight=unit_normal(vocab_sz)) + with pytest.raises(ValueError): + CrossEntropyLoss(label_smoothing=0.5) + def test_mlp() -> None: input = unit_normal(2**8, 2**10) @@ -139,7 +152,7 @@ def test_mlp() -> None: def test_mhsa() -> None: batch_sz, seq_len, hidden_dim = 2**8, 2**6, 2**6 input = unit_normal(batch_sz, seq_len, hidden_dim) - model = MHSA(hidden_dim, heads=8) + model = MHSA(hidden_dim, heads=8, dropout_p=0.1) output = model(input) assert_unit_scaled(model.linear_qkv.weight, model.linear_o.weight) @@ -157,7 +170,7 @@ def test_mhsa() -> None: def test_transformer_layer() -> None: batch_sz, seq_len, hidden_dim, heads = 2**8, 2**6, 2**6, 8 input = unit_normal(batch_sz, seq_len, hidden_dim) - model = TransformerLayer(hidden_dim, heads=heads) + model = TransformerLayer(hidden_dim, heads=heads, dropout_p=0.1) output = model(input) assert output.shape == torch.Size([batch_sz, seq_len, hidden_dim]) @@ -179,7 +192,7 @@ def test_transformer_decoder() -> None: input_idxs = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) labels = torch.roll(input_idxs, -1, 1) - model = TransformerDecoder(hidden_size, vocab_size, layers, heads) + model = TransformerDecoder(hidden_size, vocab_size, layers, heads, dropout_p=0.1) loss = model(input_idxs, labels) assert loss.shape == torch.Size([]) @@ -188,11 +201,9 @@ def test_transformer_decoder() -> None: SGD(model.parameters(), lr=1).step() for name, p in model.named_parameters(): - if "layer_norm.weight" in name: - threshold = 5.0 - elif "layer_norm.bias" in name: + if "layer_norm.bias" in name: threshold = 20.0 else: - threshold = 2.5 + threshold = 5.0 assert p.grad is not None - assert p.grad.std().detach() == pytest.approx(1, rel=threshold), name + assert 1 / threshold <= p.grad.std().detach() <= threshold, name diff --git a/unit_scaling/tests/test_utils.py b/unit_scaling/tests/test_utils.py index 901fc8c..8281f30 100644 --- a/unit_scaling/tests/test_utils.py +++ b/unit_scaling/tests/test_utils.py @@ -49,7 +49,7 @@ def test_analyse_mhsa() -> None: backward = torch.randn(batch_size, seq_len, hidden_size) annotated_code = analyse_module( - MHSA(hidden_size, heads), input, backward, syntax_highlight=False + MHSA(hidden_size, heads, dropout_p=0.1), input, backward, syntax_highlight=False ) expected_code = """ @@ -57,7 +57,7 @@ def forward(self, input : torch.Tensor) -> torch.Tensor: input_1 = input; (-> 1.0, <- 0.819) linear_qkv_weight = self.linear_qkv.weight; (-> 1.01, <- 0.681) linear = U.linear(input_1, linear_qkv_weight, None, gmean); (-> 0.766, <- 0.631) - rearrange = einops_einops_rearrange(linear, 'b s (d z h) -> z b h s d', h = 4, z = 3); (-> 0.766, <- 0.631) + rearrange = einops_einops_rearrange(linear, 'b s (z h d) -> z b h s d', h = 4, z = 3); (-> 0.766, <- 0.631) getitem = rearrange[0]; (-> 0.774, <- 0.463) getitem_1 = rearrange[1]; (-> 0.773, <- 0.34) getitem_2 = rearrange[2]; (-> 0.752, <- 0.929)