
Full transformer #11

Merged: 4 commits into transformer-layer from full-transformer, May 4, 2023

Conversation

thecharlieblake (Contributor)

No description provided.

thecharlieblake force-pushed the full-transformer branch 2 times, most recently from 56b2400 to d3c232d (April 27, 2023 16:11)

DouglasOrr (Collaborator) left a comment:

Very nice, all slotting together neatly!

LGTM (pre-approved); a bunch of minor comments / things to think about.

sparse: bool = False,
) -> Tensor:
batch_size = prod(input.shape)
weight = scale_bwd(weight, (weight.shape[0] / batch_size) ** 0.5)

DouglasOrr (Collaborator):

I think this is the right rule, but it does sometimes feel a bit risky! Perhaps in the case where it's risky (e.g. knowledge graph vocab_size=1M), the user really needs to set sparse=True and we should also do something else.

thecharlieblake (Contributor, Author):

Yeah, that's a good point; I'd forgotten about this issue. I'm tempted to say that for now we shouldn't support sparse=True, and add it to our todo list for some point down the line. For a huge vocab or a tiny batch we may have an issue.

Having said that, even for a 2**20 vocab and a 2**8 batch the scaling factor is only 64, which isn't too bad. And in the sparse setting, if you don't have that factor then maybe you just get dominated by the non-sparse decoder grads in the long run, unless you use this slightly crazy scaling for the encoder grads?
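
(Quick sanity check of that factor, taking weight.shape[0] to be the vocab size; illustrative numbers only:)

vocab_size = 2 ** 20                          # e.g. a large knowledge-graph vocab
batch_size = 2 ** 8                           # tokens contributing grads per step
bwd_scale = (vocab_size / batch_size) ** 0.5  # the rule from the snippet above
print(bwd_scale)                              # 64.0 -- big, but arguably tolerable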

DouglasOrr (Collaborator):

Sorry, it took me a long time to reply... 👍 sounds reasonable.

functionality (e.g. causal masking, positional embeddings, usage for inference).

Args:
hidden_size (int): _description_

DouglasOrr (Collaborator):

Are these _description_ placeholders? Can't see any of your magic @s to fill them in.

thecharlieblake (Contributor, Author):

Just me forgetting to write the docs 🏅🐟
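
(For the record, roughly what I have in mind for these - my wording, not final:)

Args:
    hidden_size (int): width of the hidden/embedding dimension.
    vocab_size (int): number of tokens in the vocabulary.
    layers (int): number of transformer layers.
    heads (int): number of attention heads per layer.
    dropout_p (float, optional): dropout probability. Defaults to 0.1.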

vocab_size (int): _description_
layers (int): _description_
heads (int): _description_
dropout_p (float, optional): _description_. Defaults to 0.1.

DouglasOrr (Collaborator):

I had wondered if default-on dropout felt a bit weird. (I remember asking in #9 - did you see that, or did the GitHub auto-collapsing thing get in the way?)

thecharlieblake (Contributor, Author):

Oh no. I missed 9 comments there because of auto-collapsing - lesson learned! (Bad UI? Bad user?) I'll address them here.


def forward(self, input_ids: Tensor, labels: Tensor) -> Tensor:
input = self.embedding(input_ids)
input = U.dropout(input, self.dropout_p, self.training)

DouglasOrr (Collaborator):

I usually put a layer_norm here and no dropout, but OPT and LLaMA seem to just use the embedding directly (LLaMA might have dropout, not visible in inference).

I guess this seems reasonable for now, and I presume we've got some opportunity to tweak the default before loads of people depend on it.

thecharlieblake (Contributor, Author):

I like your version; it has a symmetry to it: embed -> LN -> transformer body -> LN -> un-embed.
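
Something like this, perhaps (a rough sketch of that ordering, not the code in this PR; the module names here are made up for illustration):

def forward(self, input_ids: Tensor, labels: Tensor) -> Tensor:
    hidden = self.embedding(input_ids)   # embed
    hidden = self.initial_norm(hidden)   # LN (hypothetical module)
    hidden = self.body(hidden)           # transformer body (hypothetical)
    hidden = self.final_norm(hidden)     # LN (hypothetical module)
    logits = self.unembed(hidden)        # un-embed (hypothetical)
    return self.loss_fn(logits, labels)  # e.g. cross-entropy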

else:
threshold = 2.5
assert p.grad is not None
assert p.grad.std().detach() == pytest.approx(1, rel=threshold), name

DouglasOrr (Collaborator):

I think this test would check that std() is between [-19, 21] in the case of layer_norm.bias. Perhaps [1/20, 20] would be a better range?
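
e.g. something along these lines (a sketch only, reusing threshold / p / name from the snippet above):

assert p.grad is not None
std = p.grad.std().detach().item()
assert 1 / threshold <= std <= threshold, name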

thecharlieblake (Contributor, Author):

Updates based on review feedback (including a new feature in the docs DSL thing!). Changes here: 85074c8


for arg in unsupported_args:
if arg not in default_kwargs:
print(default_kwargs, argspec)

DouglasOrr (Collaborator):

Is this print deliberately retained?

thecharlieblake (Contributor, Author):

Oops!

if isinstance(scale, Sequence):
output_scale, left_grad_scale, right_grad_scale = scale # type: ignore
else:
output_scale = left_grad_scale = right_grad_scale = scale

DouglasOrr (Collaborator):

Perhaps this block should go to constraints as apply_ternary or something?

thecharlieblake (Contributor, Author):

I agree that what I have is a bit ugly, but I'm also a bit concerned that another level of indirection might be hard for new users to follow. Might leave this as-is for now...
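
(For reference, if we do pull it out later, the helper could look something like this - name and location per your suggestion, so purely hypothetical:)

from collections.abc import Sequence
from typing import Tuple, Union

def apply_ternary(scale: Union[float, Sequence]) -> Tuple[float, float, float]:
    # Expand a single scale, or a 3-element sequence, into
    # (output_scale, left_grad_scale, right_grad_scale).
    if isinstance(scale, Sequence):
        output_scale, left_grad_scale, right_grad_scale = scale
    else:
        output_scale = left_grad_scale = right_grad_scale = scale
    return output_scale, left_grad_scale, right_grad_scale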

DouglasOrr (Collaborator) left a comment:

Thanks, looks good! unsupported_arg is v. good.

thecharlieblake merged commit 6f87e8d into transformer-layer on May 4, 2023