Implement initial Metal port #6

Merged
zeux merged 25 commits into main from metal on Apr 11, 2024
Conversation

zeux (Owner) commented Apr 11, 2024

This change adds a Metal implementation for macOS; it's 95% functionally complete - it's missing mixture-of-experts - but hasn't been fully tuned for performance. On a base M1 it reaches ~91% of peak memory bandwidth for fp16/fp8 models and ~65% for gf4 weights (on Mistral 7B).

zeux added 25 commits April 10, 2024 19:37
This is the initial scaffolding necessary for Metal support. We compile the Metal source into a
metallib file that is then embedded as a C symbol.
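
A minimal sketch of how the embedded library might appear on the C side, using hypothetical symbol names (the actual names come from the build step); the Metal backend can then hand these bytes to MTLDevice's newLibraryWithData:error: to load the library:

```c
/* Hypothetical symbol names for the embedded .metallib blob; the real names
   are produced by the build step and may differ. */
extern const unsigned char shaders_metallib[];
extern const unsigned int shaders_metallib_len;
```
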
This code initializes the device, loads the compiled shader and runs it to apply scale to the
buffer. All of this is very temporary and hacky, and just sets the stage for incremental
development.
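
A rough sketch of what such a smoke-test kernel could look like in Metal Shading Language; the kernel name and exact argument layout here are assumptions, not the repo's actual code:

```metal
#include <metal_stdlib>
using namespace metal;

// Hypothetical smoke-test kernel: multiply every element of a buffer by a scale.
kernel void scale_buffer(constant float& scale [[buffer(0)]],
                         device float* data [[buffer(1)]],
                         uint id [[thread_position_in_grid]]) {
    data[id] *= scale;
}
```
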
run.c now uses more or less the same interface for Metal as it did for CUDA; now we "simply"
need to implement prepare/forward. In upload_metal we upload the buffer and return the
buffer handle as a pointer so that the rest of the code doesn't need to change.
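
The shape of that interface might look roughly like this on the C side; the exact signature is an assumption, only the upload_metal name comes from the commit message:

```c
#include <stddef.h>

/* Hypothetical prototype: upload host data into a Metal buffer and return the
   MTLBuffer handle as an opaque pointer, so callers stay backend-agnostic. */
void* upload_metal(void* host_data, size_t size);
```
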
For now, let's hope the kernel search doesn't take a significant amount of time.
We assume every kernel call accepts a parameter buffer at index 0 and the remaining buffers at
indices 1+; the wrapper finds the kernel and dispatches it as necessary.
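
Under that convention, every kernel would follow roughly this argument layout (the struct and kernel below are illustrative, not the repo's actual code):

```metal
#include <metal_stdlib>
using namespace metal;

// Hypothetical parameter block; each kernel's scalar arguments live at buffer index 0.
struct Args {
    uint n;
};

// Data buffers start at index 1, matching the dispatch wrapper's assumption.
kernel void example_kernel(constant Args& args [[buffer(0)]],
                           device float* out [[buffer(1)]],
                           device const float* in [[buffer(2)]],
                           uint id [[thread_position_in_grid]]) {
    if (id < args.n)
        out[id] = in[id];
}
```
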
This change reworks dispatch selection a bit to allow for templated kernels; we also
now properly implement the forward flow (modulo the actual kernel implementations/dispatch).
Also fixes a couple of issues that surfaced under validation.
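
Whether the repo uses actual MSL templates or another mechanism isn't visible from the commit message alone; one simple shape for "templated kernels" with name-based runtime selection is a templated helper plus per-type entry points, sketched here with made-up names:

```metal
#include <metal_stdlib>
using namespace metal;

// Hypothetical parameter block.
struct Args {
    uint n;
};

// The per-type logic lives in a templated helper...
template <typename T>
inline void copy_impl(device T* dst, device const T* src, uint id, uint n) {
    if (id < n)
        dst[id] = src[id];
}

// ...and each supported type gets its own entry point, so the host-side
// dispatcher can pick a kernel by name based on the tensor's element type.
kernel void copy_float(constant Args& args [[buffer(0)]],
                       device float* dst [[buffer(1)]],
                       device const float* src [[buffer(2)]],
                       uint id [[thread_position_in_grid]]) {
    copy_impl(dst, src, id, args.n);
}

kernel void copy_half(constant Args& args [[buffer(0)]],
                      device half* dst [[buffer(1)]],
                      device const half* src [[buffer(2)]],
                      uint id [[thread_position_in_grid]]) {
    copy_impl(dst, src, id, args.n);
}
```
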
For now dispatch a separate kernel; in the future this can be fused into the next matmul kernel with shared memory.
For now we don't support MoE as that requires extra code for gating et al.
For now this is missing a KV rotate kernel. We're using the same K/V layout as CUDA does to
make porting code easier; efficiency will need to be understood later.
Also fix the ffn2 buffer count (the lack of type safety is really not ideal here...)
The code is mostly blindly translated from CUDA, keeping the layout et al. The attention
scores only need 1 float for now because we probably won't be able to do early softmax
easily.
The implementation closely follows CUDA in terms of data layout et al but is untested.
Metal validation layers don't like null buffer bindings; we'd need to allocate a zero buffer for this to work.
Without this we only computed q but didn't actually update the KV cache.
The input value needs to be in the 0..1024 range, so pass i for consistency.
We are now targeting Metal 3.0 explicitly to avoid depending on an unclear default.
We need to dispatch one threadgroup per two elements; previously we would be racing over some elements
and computing them incorrectly, not to mention that the process would take more time than necessary.
Metal doesn't support fp8 type natively but we can easily convert it to half using bit casts.
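
Assuming the fp8 weights use the e5m2 layout (an assumption here; the commit doesn't name the format), the bit-cast conversion amounts to shifting the byte into the top half of a 16-bit pattern:

```metal
#include <metal_stdlib>
using namespace metal;

// e5m2 fp8 shares its sign/exponent/mantissa layout with the top byte of an
// IEEE half, so conversion is a shift plus a bit cast (MSL has no native fp8).
inline half fp8_e5m2_to_half(uchar v) {
    return as_type<half>(ushort(ushort(v) << 8));
}
```
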
The matmul kernel is not super performant right now, but at least the weight format works.
This allows us to use two fmas, which might be better. Also use floats
instead of halfs; we can figure out if halfs are worthwhile separately.
Instead of always using OpenMP from Homebrew, allow compiling and linking
without it; the CPU backend will be really slow in the absence of OpenMP, but we now
have a Metal backend to use instead.
The code is basically a 1-1 copy from CUDA, except that we linearize the thread group dimension
as the dispatch helper only supports 1D dispatches.
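
The linearization just means the kernel rebuilds its 2D coordinates from a flat index; a sketch with made-up argument names:

```metal
#include <metal_stdlib>
using namespace metal;

// Hypothetical parameter block; 'width' stands in for whatever the second
// (now folded-away) grid dimension was in the CUDA version.
struct Args {
    uint width;
    uint height;
};

// The dispatch helper only issues 1D grids, so the kernel recovers the 2D
// coordinates from the flattened index.
kernel void example_2d(constant Args& args [[buffer(0)]],
                       device float* data [[buffer(1)]],
                       uint id [[thread_position_in_grid]]) {
    uint x = id % args.width;
    uint y = id / args.width;
    if (y >= args.height)
        return;
    data[y * args.width + x] *= 2.0f;
}
```
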
When the model doesn't supply a bias tensor, we substitute a zero-initialized tensor ourselves; without this,
validation layers are unhappy if we try to dispatch the kernel without a bound buffer.
@zeux zeux merged commit 68069e4 into main Apr 11, 2024
4 checks passed
@zeux zeux deleted the metal branch April 11, 2024 23:11