Implement self-attention and update utils #9

Merged 3 commits on Apr 26, 2023. The diff below shows changes from 2 of the 3 commits.
2 changes: 1 addition & 1 deletion NOTICE.md
@@ -7,9 +7,9 @@ Our dependencies are (see [requirements.txt](requirements.txt)):
| Component | About | License |
| --- | --- | --- |
| docstring-parser | Parse Python docstrings | MIT |
| einops | Deep learning operations reinvented (for pytorch, tensorflow, jax and others) | MIT |
| numpy | Array processing library | BSD 3-Clause |
| poptorch-experimental-addons | A collection of addons to [PopTorch](https://github.com/graphcore/poptorch), with general utility | MIT |
| scipy | An open-source software for mathematics, science, and engineering | BSD 3-Clause |

We also use additional Python dependencies for development/testing (see [requirements-dev.txt](requirements-dev.txt)).

39 changes: 39 additions & 0 deletions examples/scale_analysis.py
@@ -0,0 +1,39 @@
import torch

from unit_scaling.modules import MHSA, MLP, Linear
from unit_scaling.utils import analyse_module

print("=== Unit-scaled Linear ===\n")

batch_size = 2**8
hidden_size = 2**10
out_size = 2**10
input = torch.randn(batch_size, hidden_size).requires_grad_()
backward = torch.randn(batch_size, out_size)

annotated_code = analyse_module(
Linear(hidden_size, out_size, bias=False), input, backward
)
print(annotated_code)

print("=== Unit-scaled MLP ===\n")

batch_size = 2**8
hidden_size = 2**10
input = torch.randn(batch_size, hidden_size).requires_grad_()
backward = torch.randn(batch_size, hidden_size)

annotated_code = analyse_module(MLP(hidden_size), input, backward)
print(annotated_code)

print("=== Unit-scaled MHSA ===\n")

batch_size = 2**8
seq_len = 2**6
hidden_size = 2**6
heads = 4
input = torch.randn(batch_size, seq_len, hidden_size).requires_grad_()
backward = torch.randn(batch_size, seq_len, hidden_size)

annotated_code = analyse_module(MHSA(hidden_size, heads), input, backward)
print(annotated_code)
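
For comparison, the same analysis could hypothetically be extended with an unscaled baseline. This assumes `analyse_module` also accepts ordinary `torch.nn` modules, which is not shown in this PR:

print("=== Unscaled torch.nn.Linear (baseline, for comparison) ===\n")

batch_size = 2**8
hidden_size = 2**10
out_size = 2**10
input = torch.randn(batch_size, hidden_size).requires_grad_()
backward = torch.randn(batch_size, out_size)

# Hypothetical: contrast the annotated scales of an unscaled nn.Linear with
# the unit-scaled Linear analysed at the top of this example.
annotated_code = analyse_module(
    torch.nn.Linear(hidden_size, out_size, bias=False), input, backward
)
print(annotated_code)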
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,4 +1,5 @@
docstring-parser
einops
numpy
poptorch-experimental-addons @ git+https://github.com/graphcore-research/poptorch-experimental-addons@14886d2285c3e45b0eadf4d719dae87d5f28b109
scipy
torch
3 changes: 0 additions & 3 deletions setup.cfg
@@ -14,9 +14,6 @@ ignore_missing_imports = True
[mypy-poptorch_experimental_addons.*]
ignore_missing_imports = True

[mypy-scipy.*]
ignore_missing_imports = True

# As torch.fx doesn't explicitly export many of its useful modules.
[mypy-torch.fx]
implicit_reexport = True
48 changes: 42 additions & 6 deletions unit_scaling/constraints.py
@@ -2,12 +2,12 @@

"""Common scale-constraints used in unit-scaled operations."""

from math import prod
from typing import Callable

import numpy as np
from scipy import stats

BinaryConstraint = Callable[[float, float], float]
TernaryConstraint = Callable[[float, float, float], float]
VariadicConstraint = Callable[..., float]


def gmean(*scales: float) -> float:
@@ -19,7 +19,7 @@ def gmean(*scales: float) -> float:
Returns:
float: the geometric mean.
"""
return stats.gmean(scales) # type: ignore
return float(prod(scales) ** (1 / len(scales)))


def hmean(*scales: float) -> float:
@@ -31,7 +31,7 @@ def hmean(*scales: float) -> float:
Returns:
float: the harmonic mean.
"""
return stats.hmean(scales) # type: ignore
return float(1 / (sum(1 / s for s in scales) / len(scales)))


def amean(*scales: float) -> float:
@@ -43,7 +43,7 @@ def amean(*scales: float) -> float:
Returns:
float: the arithmetic mean.
"""
return float(np.mean(scales))
return float(sum(scales) / len(scales))


def to_output_scale(output_scale: float, *grad_input_scale: float) -> float:
@@ -73,3 +73,39 @@ def to_grad_input_scale(output_scale: float, grad_input_scale: float) -> float:
float: equal to `grad_input_scale`
"""
return grad_input_scale


def to_left_grad_scale(
output_scale: float, left_grad_scale: float, right_grad_scale: float
) -> float:
"""Assumes three provided scales:
`(output_scale, left_grad_scale, right_grad_scale)`. Selects only `left_grad_scale`
as the chosen scaling factor.

Args:
output_scale (float): the scale of the op's output
left_grad_scale (float): the scale of the op's left input gradient
right_grad_scale (float): the scale of the op's right input gradient

Returns:
float: equal to `left_grad_scale`
"""
return left_grad_scale


def to_right_grad_scale(
output_scale: float, left_grad_scale: float, right_grad_scale: float
) -> float:
"""Assumes three provided scales:
`(output_scale, left_grad_scale, right_grad_scale)`. Selects only `right_grad_scale`
as the chosen scaling factor.

Args:
output_scale (float): the scale of the op's output
left_grad_scale (float): the scale of the op's left input gradient
right_grad_scale (float): the scale of the op's right input gradient

Returns:
float: equal to `right_grad_scale`
"""
return right_grad_scale
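
A quick numeric sanity check of the pure-Python means above (a sketch; the expected values follow directly from the definitions):

from unit_scaling.constraints import amean, gmean, hmean

scales = (0.25, 4.0)
print(gmean(*scales))  # (0.25 * 4.0) ** (1 / 2) = 1.0
print(hmean(*scales))  # 1 / ((1 / 0.25 + 1 / 4.0) / 2) = 0.4705...
print(amean(*scales))  # (0.25 + 4.0) / 2 = 2.125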
17 changes: 17 additions & 0 deletions unit_scaling/docs.py
@@ -97,6 +97,9 @@ def docstring_from(
def format_docstring(*args: str) -> Callable[[T], T]:
"""Returns a decorator that applies `cls.__doc__.format(*args)` to the target class.

Args:
args (*str): the arguments to be passed to the docstring's `.format()` method.

Returns:
Callable[[Type], Type]: a decorator to format the docstring.
"""
@@ -115,3 +118,17 @@ def f(cls: T) -> T:
" 'constrained' scale (usuall necessary for valid gradients). If `None` is"
" provided, no constraint will be applied. Defaults to `gmean`."
)

ternary_constraint_docstring = (
"constraint (Optional[Callable[[float, float, float], float]], optional): function"
" which takes `output_scale`, `left_grad_scale` & `right_grad_scale` (in that"
" order) and returns a single 'constrained' scale (usuall necessary for valid"
" gradients). If `None` is provided, no constraint will be applied. Defaults to"
" `gmean`."
)

variadic_constraint_docstring = (
"constraint (Optional[Callable[..., float]], optional): function"
" which takes any number of input scales and returns a single 'constrained' scale."
" If `None` is provided, no constraint will be applied. Defaults to `gmean`."
)
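
For reference, a minimal sketch of how these fragments are consumed. The class below is hypothetical and only relies on the `cls.__doc__.format(*args)` behaviour documented for `format_docstring` above:

from unit_scaling.docs import format_docstring, ternary_constraint_docstring


@format_docstring(ternary_constraint_docstring)
class ExampleOp:  # hypothetical class, for illustration only
    """A unit-scaled op with a ternary constraint.

    Args:
        {0}
    """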
98 changes: 91 additions & 7 deletions unit_scaling/functional.py
@@ -2,14 +2,21 @@

"""Unit-scaled versions of common `torch.nn.functional` functions."""

from typing import Any, Callable, Optional
import inspect
import sys
from typing import Any, Callable, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor

from .constraints import BinaryConstraint, gmean
from .docs import binary_constraint_docstring, docstring_from, format_docstring
from torch import Tensor, fx

from .constraints import BinaryConstraint, TernaryConstraint, gmean
from .docs import (
binary_constraint_docstring,
docstring_from,
format_docstring,
ternary_constraint_docstring,
)
from .scale import scale_bwd, scale_fwd


@@ -58,6 +65,68 @@ def gelu(
return scaled_gelu(input)


@docstring_from(
F.softmax,
short_description="Applies a **unit-scaled** softmax function.",
add_args=[binary_constraint_docstring],
)
def softmax(
input: Tensor,
dim: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
constraint: Optional[BinaryConstraint] = gmean,
) -> Tensor:
dim_size = input.shape[dim] if dim is not None else input.numel()
output_scale = dim_size / 1.31
grad_input_scale = dim_size / 1.65
scaled_softmax = scale_elementwise(
F.softmax, output_scale, grad_input_scale, constraint
)
return scaled_softmax(input, dim=dim, dtype=dtype)
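
A small empirical check of the softmax scaling above (a sketch, assuming unit-normal inputs; with `constraint=None` both forward and backward should come out near unit scale by construction of the two constants):

import torch
import unit_scaling.functional as U

x = torch.randn(2**12, 2**9, requires_grad=True)
y = U.softmax(x, dim=-1, constraint=None)
y.backward(torch.randn_like(y))
print(y.std(), x.grad.std())  # both expected to be roughly 1.0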


@docstring_from(
F.dropout, short_description="Applies a **unit-scaled** dropout function."
)
def dropout(
input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False
) -> Tensor:
output_scale = grad_input_scale = (1 - p) ** 0.5
scaled_dropout = scale_elementwise(
F.dropout, output_scale, grad_input_scale, constraint=None
)
return scaled_dropout(input, p, training, inplace)
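
The `(1 - p) ** 0.5` factor compensates for inverted dropout, which rescales kept elements by `1 / (1 - p)` and therefore inflates the standard deviation by `(1 - p) ** -0.5`. A quick check with plain `torch.nn.functional` (a sketch):

import torch
import torch.nn.functional as F

p = 0.5
x = torch.randn(2**20)
print(F.dropout(x, p).std())                     # roughly (1 - p) ** -0.5 ~= 1.414
print((F.dropout(x, p) * (1 - p) ** 0.5).std())  # roughly 1.0, matching output_scale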


@docstring_from(
torch.matmul,
short_description="A **unit-scaled** matrix product of two tensors.",
add_args=[ternary_constraint_docstring],
)
def matmul(
left: Tensor,
right: Tensor,
constraint: Optional[TernaryConstraint] = gmean,
) -> Tensor:
left_size = left.shape[-2]
inner_size = left.shape[-1]
right_size = right.shape[-1]

output_scale = inner_size**-0.5
left_grad_scale = right_size**-0.5
right_grad_scale = left_size**-0.5

if constraint:
output_scale = left_grad_scale = right_grad_scale = constraint(
output_scale, left_grad_scale, right_grad_scale
)

left = scale_bwd(left, left_grad_scale)
right = scale_bwd(right, right_grad_scale)
output = torch.matmul(left, right)
return scale_fwd(output, output_scale)
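
Each scale above is the inverse square root of the dimension that the corresponding pass contracts over: the forward pass sums `inner_size` products per output element, the gradient for `left` contracts over `right_size`, and the gradient for `right` contracts over `left_size`. A quick check of the forward factor with plain `torch` (a sketch under unit-normal inputs):

import torch

b, m, k, n = 64, 128, 1024, 256
left, right = torch.randn(b, m, k), torch.randn(k, n)
out = torch.matmul(left, right) * k**-0.5  # output_scale = inner_size ** -0.5
print(out.std())  # roughly 1.0, since each element sums k unit-variance products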


@docstring_from(
F.linear,
short_description="Applies a **unit-scaled** linear transformation.",
@@ -70,7 +139,7 @@ def linear(
constraint: Optional[BinaryConstraint] = gmean,
) -> Tensor:
fan_out, fan_in = weight.shape
batch_size = int(np.prod(input.shape[:-1]))
batch_size = input.numel() // fan_in

output_scale = fan_in**-0.5
grad_input_scale = fan_out**-0.5
@@ -83,3 +152,18 @@
bias = scale_bwd(bias, grad_bias_scale) if bias is not None else None
output = F.linear(input, weight, bias)
return scale_fwd(output, output_scale)


# Wrap the public functions in this module so that fx tracing doesn't recurse
# into them
def _get_public_fns() -> List[str]:
fns = []
module = sys.modules[__name__]
for name, obj in inspect.getmembers(module):
if inspect.isfunction(obj) and not name.startswith("_"):
fns.append(name)
return fns


for f in _get_public_fns():
fx.wrap(f)
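
A minimal sketch of the intended effect, assuming standard `torch.fx` tracing semantics: once wrapped, a call to one of these functions appears in a traced graph as a single `call_function` node rather than being traced through.

import torch
from torch import fx, nn

import unit_scaling.functional as U


class TinyModel(nn.Module):  # hypothetical module, for illustration only
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return U.gelu(x)


graph = fx.symbolic_trace(TinyModel()).graph
print(graph)  # expected: a single call_function node targeting gelu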