
Allow for modifying the scaled_mm compute #144

Closed

drisspg wants to merge 6 commits

Conversation

drisspg (Contributor) commented Nov 16, 2023

Summary

This does two things:

  1. Creates a new namedtuple type, ScaledMMConfig, that is used to control the behavior of the scaled_mm op. Its fields are emulate, fast_accumulation, and fp8_out_dtype (the latter is not currently used). It replaces the emulate arg, threads the config through all the relevant infra, and updates the tests accordingly (see the sketch after this list).
  2. Adds the fp8 fast-accumulation mode and enables it for the forward pass but not the backward pass.
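As a rough sketch (not the PR's exact definition; the field names follow the summary above and the defaults are assumed), the config could look like this:

```python
from collections import namedtuple

# Sketch of the config described above; field names follow the PR summary
# (emulate, fast_accumulation, fp8_out_dtype), defaults are assumed.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig",
    ["emulate", "fast_accumulation", "fp8_out_dtype"],
    defaults=[False, False, None],
)

# Fast accumulation on for the forward gemm, off for the backward gemms.
forward_config = ScaledMMConfig(emulate=False, fast_accumulation=True)
backward_config = ScaledMMConfig(emulate=False, fast_accumulation=False)
```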

Performance

With use_fast_accum set in the forward pass, using the linear_float8 benchmark:


| shape | Speedup (use_fast_accum=False) | Speedup (use_fast_accum=True) | Gain (%) |
|---|---|---|---|
| (16384, 1024, 8192) | 1.19086 | 1.26397 | 6.13912 |
| (16384, 3584, 8192) | 1.42227 | 1.48921 | 4.70629 |
| (16384, 8192, 1280) | 0.970685 | 0.986167 | 1.59497 |
| (16384, 8192, 7168) | 1.50755 | 1.54886 | 2.74022 |

drisspg mentioned this pull request Jan 16, 2024
drisspg requested a review from vkuzo on April 8, 2024
self.emulate = False
# Defines the behavior of the matmul in the forward and backward pass
self.forward_config = ScaledMMConfig()
self.backward_config = ScaledMMConfig()
Contributor:

is it possible to configure the two backward gemms separately?

drisspg (Author):

This is somewhat challenging, since as written today we don't have a very clean way of knowing which matmul is which.

cc @bdhirsh, maybe I'm not thinking of something.

We have out = x @ W where x = Float8Tensor and W = Float8Tensor.

Since W will not be used in calculating gradW, you could tag some extra info on the activation Float8Tensor, and since that tensor gets used for backward, the info should get carried through to the backward calcs.

I think this would be better as a follow-up though, since the logic gets spread out over multiple Float8Tensor instances.
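A rough illustration of that idea, with hypothetical attribute names (grad_input_config / grad_weight_config are not the PR's actual fields):

```python
# Hypothetical sketch: attach per-gemm config to the activation Float8Tensor so
# it is carried into the backward pass. Names here are illustrative only.
def tag_backward_configs(x_fp8, grad_input_config, grad_weight_config):
    # The activation x feeds both backward gemms, so config attached to it
    # survives into the backward calculations.
    x_fp8.grad_input_config = grad_input_config
    x_fp8.grad_weight_config = grad_weight_config
    return x_fp8
```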

Contributor:

> since as written today we don't have a very clean way of knowing which matmul is which

Yeah, this is weird because the config is really per-gemm but we have to stick it on a tensor. How about something like:

  1. local: given matmul(A, B), the config on B (the second argument) always overrides the config on A (the first argument); see the sketch below.
  2. global: the float8 UX allows setting options for the 3 gemms, and under the hood maps them to be implemented via (1).

While not the most intuitive to implement, I think that could work?
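A rough sketch of the local rule in (1), assuming the per-tensor config gets resolved inside the matmul handler; the helper name and attributes here are illustrative, not the PR's actual code:

```python
# Illustrative only: resolve the effective config for matmul(A, B) following
# the proposed rule that B's config (second argument) overrides A's.
def choose_mm_config(a_config, b_config):
    # If B carries a config, it wins; otherwise fall back to A's config.
    return b_config if b_config is not None else a_config

# e.g., inside a hypothetical handler for two Float8Tensors a_fp8 and b_fp8:
#   config = choose_mm_config(a_fp8.mm_config, b_fp8.mm_config)
#   ... run the scaled mm with config.emulate / config.fast_accumulation ...
```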

Contributor:

separate PR sgtm, I do feel like we need to make all 3 gemms configurable before we lock the API down.

facebook-github-bot:

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

vkuzo (Contributor) left a comment:

lgtm, can we add instructions to README.md on how to use this, and maybe mention that this is an intermediate state and that we will expose an option to configure the 2 backward gemms separately in a future PR? thanks for adding this!

facebook-github-bot:

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot:

@drisspg merged this pull request in 31877bb.
