Idea: Scaling the Down-Projection Matrix in 'Mixture of Experts' Models #294
Also it would be nice if we could post this sort of idea to 'Discussions' instead of 'Issues'! :)
Well there is a pretty big problem with the discrete uniform distribution assumption and it's causing the weights to be scaled far too much... So without actually being able to measure anything, the next best assumption is a Zipf distribution:

```python
import numpy as np

def zipf_distribution(N, s):
    """Generate Zipf distribution for N experts with parameter s."""
    ranks = np.arange(1, N + 1)
    weights = 1 / np.power(ranks, s)
    normalization = np.sum(weights)
    probabilities = weights / normalization
    return probabilities

def expected_norm_squared(probabilities, num_experts):
    """Calculate the expected norm squared for a subset of experts."""
    return np.sum(probabilities[:num_experts] ** 2)

def calculate_scaling_factor(N, n, m, s):
    """Calculate the scaling factor alpha for given N, n, m, and s."""
    probabilities = zipf_distribution(N, s)
    norm_squared_n = expected_norm_squared(probabilities, n)
    norm_squared_m = expected_norm_squared(probabilities, m)
    alpha = np.sqrt(norm_squared_n / norm_squared_m)
    return alpha

N = 8  # num_local_experts
n = 2  # num_experts_per_tok
s = 0  # Skew parameter (0 = Uniform, 0.5 = Square-Root, 1 = Zipf's law)

# Print the Zipf distribution for the given s
probabilities = zipf_distribution(N, s)
print(f"Zipf distribution for s = {s}: {[f'{p:.4f}' for p in probabilities]}")

# Loop over all values of m from 1 to N
for m in range(1, N + 1):
    alpha = calculate_scaling_factor(N, n, m, s)
    print(f"For m = {m}, Scaling factor alpha: {alpha:.4f}")
```
I'll see if I can run a grid-search overnight.
Here's the config:

```yaml
# mergekit-yaml --verbose --cuda mixtral-scaled.yaml mixtral-scaled-m
# ~/LLMs/llama.cpp/convert.py mixtral-scaled-m --outfile mixtral-scaled-m.gguf --outtype q8_0
# ~/LLMs/llama.cpp/build/bin/perplexity -m mixtral-scaled-m.gguf -f ~/LLMs/misc/datasets/wikitext-2-raw//wiki.test.raw -ngl 1000

const_tag: &MODEL Mixtral-8x7B-Instruct-v0.1

############################################################################
# Don't forget to also set `num_experts_per_tok` value in `config.json`!!! #
############################################################################

#const_tag: &RESIDUAL_SCALE_FACTOR 1.1180  # [s=0 --> 7.2995]
#const_tag: &RESIDUAL_SCALE_FACTOR 1.0     # 4.4103 +/- 0.02355
const_tag: &RESIDUAL_SCALE_FACTOR 0.9583   # [s=0 --> 4.6758]

# The `down_proj` of each MLP expert seems to be held in the `w2.weight` tensor for Mixtral:
# > current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
# > current_hidden_states = self.w2(current_hidden_states)

models:
  - model: *MODEL
    parameters:
      scale:
        - filter: w2.weight
          value: *RESIDUAL_SCALE_FACTOR
        - value: 1.0

dtype: bfloat16
merge_method: passthrough
```
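For anyone not using mergekit, a rough equivalent in plain transformers might look like the sketch below (my own illustration, not from this thread; the `w2.weight` suffix match assumes the Hugging Face Mixtral tensor naming, i.e. `model.layers.*.block_sparse_moe.experts.*.w2.weight`):

```python
# Hypothetical sketch: scale every expert's down-projection (w2) in place.
import torch
from transformers import AutoModelForCausalLM

RESIDUAL_SCALE_FACTOR = 0.9583  # same value as in the YAML above

model = AutoModelForCausalLM.from_pretrained(
    "Mixtral-8x7B-Instruct-v0.1",  # same local model path as in the YAML above
    torch_dtype=torch.bfloat16,
)
with torch.no_grad():
    for name, param in model.named_parameters():
        if name.endswith("w2.weight"):  # leave all other tensors untouched
            param.mul_(RESIDUAL_SCALE_FACTOR)
model.save_pretrained("mixtral-scaled-m")
```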
This isn't doing much useful... For 3 experts: [perplexity results omitted] vs 2 experts & stock settings --> PPL = 4.4103 +/- 0.02355. It still may be useful to try setting [...]
So next I'm going to try to attenuate the MoE-routing softmax-gate's distribution, and then the score matrix like we did for the frankenmerges.
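I'm guessing at the exact form here, but since the Mixtral router is just a bias-free linear layer (the `gate` module in the HF implementation), one way to attenuate the softmax-gate's distribution is simply to scale the gate weight tensor: multiplying `gate.weight` by $t$ multiplies the router logits by $t$, which is equivalent to a softmax temperature of $1/t$ ($t < 1$ flattens the gate distribution, $t > 1$ sharpens it). A toy check with made-up numbers:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
h = rng.normal(size=16)            # toy hidden state
W_gate = rng.normal(size=(8, 16))  # toy router weight for 8 experts, no bias

t = 0.5  # attenuation factor applied to the gate weight tensor
# Scaling the weight tensor is identical to scaling the logits (linearity):
assert np.allclose(softmax((t * W_gate) @ h), softmax(t * (W_gate @ h)))

print("t = 1.0:", np.round(softmax(W_gate @ h), 3))
print("t = 0.5:", np.round(softmax(0.5 * (W_gate @ h)), 3))  # flatter distribution
```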
Not really worth bothering with, I think. At best just going to get something about the same but slower to run.
Maybe this idea does have some use after all. If we can scale the gate weight tensor with n=8 to work as closely as possible to n=2, then very low-bit quantized models using 2-3 bpw might actually work better and see their perplexity degrade more slowly (due to more active weights cancelling out more of the noise caused by quantization). This assumes that the optimal scale factor doesn't just approximate the hard n=2 thresholding with a soft n=8 version that barely uses the other 6 sets of MLP weights (ie: doesn't shift the lower-valued logits so far down that the Gumbel error distributions effectively head towards -inf and contribute almost nothing to the gated sum...).
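A toy illustration of that last point (made-up logits, not measured from the model): as the router logits are scaled up, the full 8-expert softmax puts less and less mass on the 6 experts outside the top 2, so a "soft n=8" can end up almost indistinguishable from the hard top-2 gating.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Made-up router logits for a single token.
logits = np.array([2.1, 1.8, 0.4, 0.1, -0.3, -0.7, -1.2, -1.5])

p = softmax(logits)
top2 = np.argsort(p)[-2:]  # indices of the two largest gate values

for t in (0.5, 1.0, 2.0, 4.0):
    soft = softmax(t * logits)
    outside = 1.0 - soft[top2].sum()
    print(f"logit scale t = {t}: gate mass outside the top-2 experts = {outside:.3f}")
```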
Problem
In a Mixture of Experts (MoE) LLM, the gating network outputs a categorical distribution of $n$ values (chosen from $n_{max}$), which is then used to create a convex combination of the $n$ outputs of the chosen expert MLP blocks only (eg: $n$ = 2 and $n_{max}$ = 8 for Mixtral-8x7b and Mixtral-8x22b).

If the model was trained to choose only the top $n$ experts and we want to change the chosen number of experts to $m$, how should we scale the down-projection matrix of the MLP to maintain the expected norm of the final output?

Solution
For simplicity, let's assume that the output of each expert is an i.i.d. random vector with a norm of $r$ and the gating network outputs a discrete uniform distribution where $g_i = \frac{1}{n}$ for all $i$. The final output is a convex combination of the expert outputs:

$$y = \sum_{i=1}^{n} g_i \, x_i$$

The expected norm of this output is:

NOTE: The last equality holds only for a balanced distribution, where $g_i = \frac{1}{n}$ for all $i$.

If we change the number of experts to $m$, and the gating network outputs a balanced distribution over $m$ experts, the expected norm of the output becomes:

To make the expected norm of the output with $m$ experts equal to the expected norm of the output with $n$ experts, we need to scale the down-projection matrix of the MLP by a factor of $\sqrt{\frac{n}{m}}$:

$$W_{down}' = \sqrt{\frac{n}{m}} \, W_{down}$$

With this scaling, the expected norm of the output with $m$ experts is the same as the expected norm of the output with $n$ experts.

Scale Factor
The scale factor $\sqrt{\frac{n}{m}}$ depends only on the ratio of the original number of experts ($n$) to the new number of experts ($m$). It does not depend on the norm $r$ of the expert outputs (with the given assumptions...).
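As a concrete instance of that formula: a frankenMoE built with mergekit-moe from dense models (which were effectively trained with $n = 1$ "expert" active) and then run with $m$ active experts would get a scale factor of $\sqrt{\frac{1}{m}} = \frac{1}{\sqrt{m}}$, which is the same factor suggested for the `residual_scale` parameter at the end of the thread.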
(sorry for the AI-generated text again - but it's so much easier than trying to write all that LaTeX!)
This all assumes I have correctly understood what the Mixtral-style MoE architecture is doing though (it's not 100% clear from the paper).
If this shows promise then the i.i.d. assumption and the discrete uniform distribution simplification can be removed by sampling the actual outputs of the expert MLPs / gating networks (the i.i.d. assumption can be improved on if we are happy to just guess values for $\rho$ [see the other thread for example], but to use a concrete categorical distribution we would need to sample from it I think).
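A minimal sketch of what that sampling could look like (my own illustration, not from the thread): it keeps the same expected-sum-of-squared-gate-values criterion as the Python snippet earlier, but replaces the assumed Zipf/uniform distribution with gate values sampled from a stand-in router (i.i.d. Gaussian logits pushed through a softmax); in practice you would log the real gate values from the model instead.

```python
import numpy as np

rng = np.random.default_rng(0)
N, n = 8, 2           # total experts, experts the model was trained to use
num_samples = 100_000

# Stand-in for real router outputs: i.i.d. Gaussian logits -> softmax.
logits = rng.normal(size=(num_samples, N))
gates = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
gates = -np.sort(-gates, axis=1)  # sort each row's gate values in descending order

# Same criterion as the closed-form/Zipf version: match the expected sum of
# squared gate values over the top-n experts to that over the top-m experts.
expected_sq = (gates ** 2).cumsum(axis=1).mean(axis=0)  # index k-1 -> top-k experts
for m in range(1, N + 1):
    alpha = np.sqrt(expected_sq[n - 1] / expected_sq[m - 1])
    print(f"m = {m}: alpha = {alpha:.4f}")
```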
I'm going to try this on `Mixtral-8x7b-Instruct` now and see if it improves the perplexity vs previous attempts:
https://rentry.org/HowtoMixtral
https://old.reddit.com/r/LocalLLaMA/comments/18m6zjz/for_exllamav2_how_many_mixtral_experts_are/
@cg123 I see you already have a parameter called `residual_scale`, so for the `mergekit-moe` merges it should be pretty easy to try scaling the models designed to not be in a MoE by $\frac{1}{\sqrt{m}}$, etc.