Full transformer #11
Changes from all commits: cddcbae, 9ea2b93, 8c56644, 85074c8
@@ -72,11 +72,11 @@ def gelu(
 )
 def softmax(
     input: Tensor,
-    dim: Optional[int] = None,
+    dim: int,
     dtype: Optional[torch.dtype] = None,
     constraint: Optional[BinaryConstraint] = gmean,
 ) -> Tensor:
-    dim_size = input.shape[dim] if dim is not None else input.numel()
+    dim_size = input.shape[dim]
     # Scale factors determined empirically, assuming unit-scaled & large dim_size
     output_scale = dim_size / 1.31
     grad_input_scale = dim_size / 1.65
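As a rough numerical illustration of the constants above (plain PyTorch only, not this library's code): for a unit-variance input and large `dim_size`, scaling the softmax output by `dim_size / 1.31` and the incoming gradient by `dim_size / 1.65` should bring both back to roughly unit standard deviation.

```python
# Illustrative check of the scale factors in the diff above.
import torch
import torch.nn.functional as F

dim_size = 4096
x = torch.randn(64, dim_size, requires_grad=True)

y = F.softmax(x, dim=-1)
print("raw softmax output std:  ", y.std().item())
print("rescaled output std:     ", (y * dim_size / 1.31).std().item())

# Backward pass: feed in a unit-variance gradient and inspect grad std,
# with and without the dim_size / 1.65 correction.
y.backward(torch.randn_like(y))
print("raw grad_input std:      ", x.grad.std().item())
print("rescaled grad_input std: ", (x.grad * dim_size / 1.65).std().item())
```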
@@ -87,7 +87,9 @@ def softmax(
 
 
 @docstring_from(
-    F.dropout, short_description="Applies a **unit-scaled** dropout function."
+    F.dropout,
+    short_description="Applies a **unit-scaled** dropout function.",
+    unsupported_args=["inplace"],
 )
 def dropout(
     input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False
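For context on why a unit-scaled dropout differs from `F.dropout`: standard dropout divides the kept elements by `1 - p`, which inflates output variance by `1 / (1 - p)`, so an extra `(1 - p) ** 0.5` factor restores unit scale. A minimal plain-PyTorch sketch of that reasoning (not necessarily the exact implementation here):

```python
# Standard dropout inflates variance by 1 / (1 - p); an extra
# (1 - p) ** 0.5 factor brings the output back to roughly unit std.
import torch
import torch.nn.functional as F

p = 0.5
x = torch.randn(1_000_000)

standard = F.dropout(x, p=p, training=True)
unit_scaled = standard * (1 - p) ** 0.5

print("standard dropout std:   ", standard.std().item())    # ~ 1 / sqrt(1 - p)
print("unit-scaled dropout std:", unit_scaled.std().item())  # ~ 1.0
```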
@@ -118,9 +120,11 @@ def matmul(
     right_grad_scale = left_size**-0.5
 
     if constraint:
-        output_scale = left_grad_scale = right_grad_scale = constraint(
-            output_scale, left_grad_scale, right_grad_scale
-        )
+        scale = constraint(output_scale, left_grad_scale, right_grad_scale)
+        if isinstance(scale, Sequence):
+            output_scale, left_grad_scale, right_grad_scale = scale  # type: ignore
+        else:
+            output_scale = left_grad_scale = right_grad_scale = scale
Comment: Perhaps this block should go to …

Reply: I agree what I have is a bit ugly, but I'm also a bit concerned that another level of indirection might be hard for new users to follow. Might leave this as-is for now...
 
     left = scale_bwd(left, left_grad_scale)
     right = scale_bwd(right, right_grad_scale)
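Regarding the suggestion above to factor this block out: a hypothetical helper (the name `apply_constraint` and its signature are illustrative, not part of this PR) might look like the following.

```python
# Hypothetical helper that the constraint-handling block could be moved into.
# A constraint callable either returns one shared scale or a per-tensor sequence.
from collections.abc import Sequence
from typing import Callable, Optional, Tuple, Union

Scale = float
ConstraintFn = Callable[..., Union[Scale, Sequence[Scale]]]


def apply_constraint(
    constraint: Optional[ConstraintFn], *scales: Scale
) -> Tuple[Scale, ...]:
    """Return the input scales, reconciled by `constraint` if one is given."""
    if constraint is None:
        return scales
    result = constraint(*scales)
    if isinstance(result, Sequence):
        assert len(result) == len(scales), "constraint returned wrong arity"
        return tuple(result)
    # A scalar result ties all scales to the same value.
    return (result,) * len(scales)
```

With something like this, the body of `matmul` would collapse to a single unpacking call, at the cost of the extra indirection mentioned above.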
@@ -219,3 +223,70 @@ def residual_add(residual: Tensor, skip: Tensor, tau: float = 0.2) -> Tensor:
     residual = scale_fwd(residual, tau**0.5)
     skip = scale_fwd(skip, (1 - tau) ** 0.5)
     return residual + skip
+
+
+@docstring_from(
+    F.embedding,
+    short_description=(
+        "A **unit-scaled** lookup table that looks up embeddings in a fixed dictionary"
+        " and size."
+    ),
+    unsupported_args=["scale_grad_by_freq", "sparse"],
+)
+def embedding(
+    input: Tensor,
+    weight: Tensor,
+    padding_idx: Optional[int] = None,
+    max_norm: Optional[float] = None,
+    norm_type: float = 2.0,
+    scale_grad_by_freq: bool = False,
+    sparse: bool = False,
+) -> Tensor:
+    batch_size = prod(input.shape)
+    weight = scale_bwd(weight, (weight.shape[0] / batch_size) ** 0.5)
Comment: I think this is the right rule, but it does sometimes feel a bit risky! Perhaps in the case where it's risky (e.g. knowledge graph …)

Reply: Yeah that's a good point, I'd forgotten about this issue. I'm tempted to say that for now we shouldn't support sparse=true and add that to our todo list for some point down the line. For huge vocab or tiny batch we may have an issue. Having said that, even for …

Reply: Sorry, feels like a long time to reply... 👍 sounds reasonable.
+    return F.embedding(
+        input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse
+    )
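A rough plain-PyTorch illustration of the `(weight.shape[0] / batch_size) ** 0.5` rule and the gradient-sparsity concern discussed above: each weight row receives on average `batch_size / vocab_size` unit-scale gradient contributions, so the raw weight gradient has std of roughly `(batch_size / vocab_size) ** 0.5` and the rescaling brings it back towards 1.

```python
# Illustration of the embedding backward-scale rule (not this library's code).
import torch
import torch.nn.functional as F

vocab_size, batch_size, hidden = 1024, 4096, 256
weight = torch.randn(vocab_size, hidden, requires_grad=True)
tokens = torch.randint(0, vocab_size, (batch_size,))

out = F.embedding(tokens, weight)
out.backward(torch.randn_like(out))  # unit-scale incoming gradient

grad_std = weight.grad.std().item()
print("raw weight grad std:     ", grad_std)
print("rescaled weight grad std:", grad_std * (vocab_size / batch_size) ** 0.5)
```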
+
+
+@docstring_from(
+    F.cross_entropy,
+    short_description=(
+        "Computes a **unit-scaled** cross entropy loss between input logits and"
+        " target."
+    ),
+    unsupported_args=["weight", "size_average", "reduce", "label_smoothing"],
+)
+def cross_entropy(
+    input: Tensor,
+    target: Tensor,
+    weight: Optional[Tensor] = None,
+    size_average: Optional[bool] = None,
+    ignore_index: int = -100,
+    reduce: Optional[bool] = None,
+    reduction: str = "mean",
+    label_smoothing: float = 0.0,
+) -> Tensor:
+    if len(input.shape) == 2:
+        batch_size, vocab_size = input.shape
+    elif len(input.shape) == 1:
+        batch_size, vocab_size = 1, input.shape[0]
+    else:
+        assert False, (
+            f"cross_entropy input shape is {input.shape}, but should be either"
+            " (vocab_size,) or (batch_size, vocab_size)"
+        )
+    input = scale_bwd(input, vocab_size / (vocab_size - 1) ** 0.5)
+    loss = F.cross_entropy(
+        input,
+        target,
+        weight,
+        size_average,
+        ignore_index,
+        reduce,
+        reduction="sum",
+        label_smoothing=label_smoothing,
+    )
+    if reduction == "mean":
+        return scale_fwd(loss, 1 / batch_size)
+    return loss
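A plain-PyTorch check of the reduction handling above: a sum-reduced loss divided by `batch_size` equals the mean-reduced loss, so applying the `1 / batch_size` factor via `scale_fwd` (which, going by its use elsewhere in this file, scales the forward pass only) preserves the reported loss value while keeping the larger, sum-scale gradient in the backward pass.

```python
# Forward-value equivalence of "sum / batch_size" and "mean" reduction.
import torch
import torch.nn.functional as F

batch_size, vocab_size = 32, 1000
logits = torch.randn(batch_size, vocab_size)
target = torch.randint(0, vocab_size, (batch_size,))

sum_loss = F.cross_entropy(logits, target, reduction="sum")
mean_loss = F.cross_entropy(logits, target, reduction="mean")
print(torch.allclose(sum_loss / batch_size, mean_loss))  # True (up to float tolerance)
```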
Comment: Is this `print` deliberately retained?

Reply: Oops!