Full transformer #11
Force-pushed from 56b2400 to d3c232d.
Very nice, all slotting together neatly!
LGTM (pre-approved); a bunch of minor comments / things to think about.
    sparse: bool = False,
) -> Tensor:
    batch_size = prod(input.shape)
    weight = scale_bwd(weight, (weight.shape[0] / batch_size) ** 0.5)
I think this is the right rule, but it does sometimes feel a bit risky! Perhaps in the case where it's risky (e.g. a knowledge graph with vocab_size=1M), the user really needs to set sparse=True and we should also do something else.
Yeah, that's a good point, I'd forgotten about this issue. I'm tempted to say that for now we shouldn't support sparse=True, and add that to our todo list for some point down the line. For a huge vocab or a tiny batch we may have an issue.
Having said that, even for 2**20 vocab and 2**8 batch the scaling factor is only 64, which isn't too bad. And in the sparse setting, if you don't have that then maybe you just get dominated by the non-sparse decoder grads in the long run, unless you have this slightly crazy scaling for the encoder grads?
Sorry, it feels like I took a long time to reply... 👍 sounds reasonable.
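As a quick sanity check of the numbers above, here is a sketch of the scale factor arithmetic, assuming weight.shape[0] is the vocab size and batch_size counts all tokens in the step (the vocab/batch values are the hypothetical ones from the comment, not anything used in the library):

# Backward scale from the diff above: (weight.shape[0] / batch_size) ** 0.5,
# evaluated for the hypothetical huge-vocab / tiny-batch case discussed here.
vocab_size = 2**20   # hypothetical knowledge-graph-sized vocab
batch_size = 2**8    # hypothetical tiny batch (total tokens in the step)
grad_scale = (vocab_size / batch_size) ** 0.5
print(grad_scale)    # 64.0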
unit_scaling/modules.py (outdated)
    functionality (e.g. causal masking, positional embeddings, usage for inference).

    Args:
        hidden_size (int): _description_
Are these _description_ placeholders? Can't see any of your magic @s to fill them in.
Just me forgetting to write the docs 🏅🐟
unit_scaling/modules.py (outdated)
        vocab_size (int): _description_
        layers (int): _description_
        heads (int): _description_
        dropout_p (float, optional): _description_. Defaults to 0.1.
I had wondered if default-on dropout felt a bit weird. (I remember asking in #9. Did you see that, or did the GitHub auto-collapsing thing get in the way?)
Oh no, I missed 9 comments there because of auto-collapsing - lesson learned! (bad UI? bad user?) I'll address them here.
unit_scaling/modules.py (outdated)
    def forward(self, input_ids: Tensor, labels: Tensor) -> Tensor:
        input = self.embedding(input_ids)
        input = U.dropout(input, self.dropout_p, self.training)
I usually put a layer_norm here and no dropout, but OPT and LLaMA seem to just use the embedding directly (LLaMA might have dropout, it's just not visible in inference code).
I guess this seems reasonable for now, and I presume we've got some opportunity to tweak the default before loads of people depend on it.
I like your version, it has a symmetry to it: embed->LN->transformer body->LN->un-embed
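For illustration, a rough plain-PyTorch sketch of that symmetric layout; the module choices and names here are illustrative assumptions, not the unit_scaling API:

import torch
import torch.nn as nn

class SymmetricTransformer(nn.Module):
    """Sketch of embed -> LN -> transformer body -> LN -> un-embed."""

    def __init__(self, vocab_size: int, hidden_size: int, layers: int, heads: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pre_norm = nn.LayerNorm(hidden_size)
        self.body = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_size, heads, batch_first=True),
            num_layers=layers,
        )
        self.post_norm = nn.LayerNorm(hidden_size)
        self.unembed = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.pre_norm(self.embedding(input_ids))  # embed -> LN
        x = self.body(x)                              # transformer body
        return self.unembed(self.post_norm(x))        # LN -> un-embed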
unit_scaling/tests/test_modules.py (outdated)
        else:
            threshold = 2.5
        assert p.grad is not None
        assert p.grad.std().detach() == pytest.approx(1, rel=threshold), name
I think this test would check that std() is between [-19, 21] in the case of layer_norm.bias. Perhaps [1/20, 20] would be a better range?
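For comparison, a small sketch of the two checks; the threshold = 20 value is inferred from the [-19, 21] range above and is an assumption about the layer_norm.bias branch, not taken from the actual test:

import pytest

threshold = 20   # assumed layer_norm.bias threshold (inferred, not from the test)
grad_std = 0.3   # example gradient std

# Current form: approx(1, rel=20) accepts anything in [1 - 20, 1 + 20] = [-19, 21].
assert grad_std == pytest.approx(1, rel=threshold)

# Multiplicative alternative: accept only [1/20, 20].
assert 1 / threshold <= grad_std <= threshold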
Force-pushed from d3c232d to 85074c8.
Updates based on review feedback (including a new feature in the docs DSL thing!) Changes here: 85074c8
for arg in unsupported_args:
    if arg not in default_kwargs:
        print(default_kwargs, argspec)
Is this print deliberately retained?
Oops!
if isinstance(scale, Sequence):
    output_scale, left_grad_scale, right_grad_scale = scale  # type: ignore
else:
    output_scale = left_grad_scale = right_grad_scale = scale
Perhaps this block should go to constraints as apply_ternary or something?
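For illustration, a hypothetical apply_ternary helper along those lines (the name and its placement in constraints are just the suggestion above, not existing code):

from collections.abc import Sequence
from typing import Tuple, Union

def apply_ternary(
    scale: Union[float, Tuple[float, float, float]]
) -> Tuple[float, float, float]:
    """Expand a scalar (or 3-tuple) scale into (output, left_grad, right_grad) scales."""
    if isinstance(scale, Sequence):
        output_scale, left_grad_scale, right_grad_scale = scale
        return output_scale, left_grad_scale, right_grad_scale
    return scale, scale, scale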
I agree what I have is a bit ugly, but I'm also a bit concerned that another level of indirection might be hard for new users to follow. Might leave this as-is for now...
Thanks, looks good! unsupported_arg is v. good.