Allow for modifying the scaled_mm compute #144
Conversation
self.emulate = False
# Defines the behavior of the matmul in the forward and backward pass
self.forward_config = ScaledMMConfig()
self.backward_config = ScaledMMConfig()
is it possible to configure the two backward gemms separately?
This is somewhat challenging, since as written today we don't have a very clean way of knowing which matmul is which.
cc @bdhirsh, maybe I'm not thinking of something.
We have out = x @ W, where x = Float8Tensor and W = Float8Tensor.
Since W will not be used in calculating gradW, you could tag some extra info on the activation Float8Tensor, and since the activation gets used for backward, that info should get carried through to the backward calcs.
I think this would be better as a follow-up, though, since the logic gets spread out over multiple Float8Tensor instances.
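A minimal sketch of that tagging idea, using a toy stand-in for the activation Float8Tensor (all class, field, and helper names here are illustrative assumptions, not the library's actual API):

```python
from typing import NamedTuple

class ScaledMMConfig(NamedTuple):
    # Illustrative fields; the config added in this PR may differ.
    emulate: bool = False
    fast_accumulation: bool = False

class ActivationFloat8(NamedTuple):
    """Toy stand-in for the activation Float8Tensor carrying extra tagged info."""
    data: object                        # placeholder for the fp8 payload
    mm_config: ScaledMMConfig           # config used for the forward gemm
    grad_weight_config: ScaledMMConfig  # extra info tagged for the gradW gemm

# Because the activation is saved for backward, the gemm that computes gradW
# (which consumes the saved activation) could read x.grad_weight_config and
# configure that matmul independently of the forward one.
```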
> since as written today we don't have a very clean way of knowing which matmul is which
Yeah, this is weird because the config is really per-gemm but we have to stick it on a tensor. How about something like:
1. local: given matmul(A, B), the config on B (the second argument) always overrides the config on A (the first argument); a sketch of this rule follows below.
2. global: the float8 UX allows setting options for the 3 gemms, and under the hood maps them to be implemented via (1).
While not the most intuitive to implement, I think that could work?
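A minimal sketch of rule (1), assuming each operand may carry an optional config (the function and type names here are illustrative, not the actual API):

```python
from typing import NamedTuple, Optional

class ScaledMMConfig(NamedTuple):
    # Illustrative fields only.
    emulate: bool = False
    fast_accumulation: bool = False

def resolve_mm_config(
    a_config: Optional[ScaledMMConfig],
    b_config: Optional[ScaledMMConfig],
) -> ScaledMMConfig:
    # Local rule: for matmul(A, B), the config attached to B (the second
    # argument) wins whenever it is present.
    if b_config is not None:
        return b_config
    return a_config if a_config is not None else ScaledMMConfig()
```

Point (2) would then map the three user-facing per-gemm settings onto per-tensor configs so that this override rule produces the intended behavior.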
separate PR sgtm, though I do feel like we need to make all 3 gemms configurable before we lock the API down.
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
lgtm, can we add instructions to README.md on how to use this, and maybe mention that this is an intermediate state and that we will expose an option to configure the 2 backward gemms separately in a future PR? Thanks for adding this!
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary
This does two things:
1. Adds a ScaledMMConfig (sketched below) that is used to control the behavior of the scaled_mm op. This includes emulate, fast_accumulation, and fp8_out_dtype (the latter is not currently used).
2. Replaces the emulate arg with this config, strings it through all the relevant infra, and updates the tests accordingly.
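Roughly, the config described in (1) could look like the following sketch (field names follow this summary; the exact definition, defaults, and usage in the PR may differ):

```python
from typing import NamedTuple, Optional
import torch

class ScaledMMConfig(NamedTuple):
    emulate: bool = False            # run an emulated matmul instead of the fused scaled_mm kernel
    fast_accumulation: bool = False  # enable fast accumulation in the scaled_mm op
    fp8_out_dtype: Optional[torch.dtype] = None  # output dtype override; not currently used

# Instead of threading a bare `emulate` flag through the stack, modules can
# carry one config per direction, e.g.:
forward_config = ScaledMMConfig(emulate=False, fast_accumulation=True)
backward_config = ScaledMMConfig(emulate=False, fast_accumulation=False)
```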
Performance
With use_fast_accum enabled in the forward pass, using the linear_float8 benchmark: