Enables on-demand loading for the AutoAWQ kernel
gnpinkert committed Oct 10, 2024
1 parent 5874d30 commit 98e3f69
Showing 3 changed files with 36 additions and 7 deletions.
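In brief, as read from the diff below: AWQ-Marlin weights for `ReplicatedLinear` layers are now allocated on the CPU at load time, a new `load_to_gpu` hook copies an expert's packed weights to CUDA only when the router actually selects that expert, and `AWQMarlinLinearMethod.apply` frees the device copy again immediately after the matmul, so idle experts never occupy GPU memory.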
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/linear.py
@@ -233,10 +233,13 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
+    def load_to_gpu(self):
+        self.quant_method.load_to_gpu(layer=self)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
-        output = self.quant_method.apply(self, x, bias)
+        output = self.quant_method.apply(layer=self, x=x, bias=bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
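The new `load_to_gpu` on the linear layer simply delegates to its quant method, so a quantization backend opts in by implementing the hook. A minimal usage sketch (the call site is hypothetical; vLLM linear layers return an `(output, output_bias)` tuple):

    layer.load_to_gpu()             # delegates to self.quant_method.load_to_gpu(layer=self)
    output, output_bias = layer(x)  # apply() will prefer the staged GPU copy of qweight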
22 changes: 19 additions & 3 deletions vllm/model_executor/layers/quantization/awq_marlin.py
@@ -138,6 +138,10 @@ class AWQMarlinLinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: AWQMarlinConfig) -> None:
         self.quant_config = quant_config
+        self.qweight_gpu = None
+
+    def load_to_gpu(self, layer: torch.nn.Module):
+        self.qweight_gpu = layer.qweight.to('cuda')
 
     def create_weights(
         self,
@@ -152,6 +156,9 @@ def create_weights(
         del output_size
         output_size_per_partition = sum(output_partition_sizes)
         weight_loader = extra_weight_attrs.get("weight_loader")
+        local_device = 'cuda'
+        if layer.__class__.__qualname__ == 'ReplicatedLinear':
+            local_device = 'cpu'
 
         # Normalize group_size
         if self.quant_config.group_size != -1:
@@ -170,6 +177,7 @@ def create_weights(
                 input_size_per_partition,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
+                device=local_device
             ),
             input_dim=0,
             output_dim=1,
@@ -258,10 +266,11 @@ def apply(
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
+        gpu_qweight=None
     ) -> torch.Tensor:
-        return apply_awq_marlin_linear(
+        val = apply_awq_marlin_linear(
             input=x,
-            weight=layer.qweight,
+            weight=self.qweight_gpu if self.qweight_gpu is not None else layer.qweight,
             weight_scale=layer.scales,
             weight_zp=layer.qzeros,
             g_idx=layer.g_idx,
@@ -270,4 +279,11 @@
             quant_type=self.quant_config.quant_type,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
-            bias=bias)
+            bias=bias)
+
+        if self.qweight_gpu is not None:
+            del self.qweight_gpu
+            torch.cuda.empty_cache()
+            self.qweight_gpu = None
+
+        return val
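`apply` now prefers the staged `qweight_gpu` copy and drops it as soon as the kernel has run, so at most one expert's packed weights occupy device memory at a time. A minimal standalone sketch of that stage-use-free pattern (shapes and names here are illustrative, not from the commit):

    import torch

    # Packed int32 AWQ-style weights living on the CPU (hypothetical shape).
    w_cpu = torch.randint(0, 2**31 - 1, (4096, 512), dtype=torch.int32)

    w_gpu = w_cpu.to('cuda')    # stage: copy to the device just before the kernel runs
    # ... run the quantized matmul against w_gpu here ...
    del w_gpu                   # drop the last reference so the device copy can be freed
    torch.cuda.empty_cache()    # hand the cached block back to the CUDA allocator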
16 changes: 13 additions & 3 deletions vllm/model_executor/models/mixtral_quant.py
@@ -87,6 +87,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         current_hidden_states = w1_out * w3_out
         current_hidden_states, _ = self.w2(current_hidden_states)
         return current_hidden_states
+
+    def load_to_gpu(self):
+        self.w1.load_to_gpu()
+        self.w2.load_to_gpu()
+        self.w3.load_to_gpu()
 
 
 class MixtralMoE(nn.Module):
@@ -136,16 +141,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         routing_weights, selected_experts = torch.topk(routing_weights,
                                                        self.top_k,
                                                        dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
 
         final_hidden_states = None
+        unique_values = torch.unique(selected_experts)
         for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
+            if expert_idx not in unique_values:
+                continue
+
+            expert_mlp = self.experts[expert_idx]
+            expert_mlp.load_to_gpu()
             expert_mask = (selected_experts == expert_idx)
             expert_weights = (routing_weights * expert_mask).sum(dim=-1,
                                                                  keepdim=True)
 
-            current_hidden_states = expert_layer(hidden_states).mul_(
-                expert_weights)
+            current_hidden_states = expert_mlp(hidden_states).mul_(
+                expert_weights)
             if final_hidden_states is None:
                 final_hidden_states = current_hidden_states
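The `torch.unique` filter means an expert's weights are staged onto the GPU only when at least one token was actually routed to it. A toy sketch of that filter (values are hypothetical):

    import torch

    selected_experts = torch.tensor([[0, 3], [3, 5]])  # top-k expert ids per token
    unique_values = torch.unique(selected_experts)     # tensor([0, 3, 5])
    for expert_idx in range(8):                        # stand-in for self.expert_indicies
        if expert_idx not in unique_values:
            continue                                   # never routed to: skip the GPU load
        print(f"expert {expert_idx} would be loaded and run")  # fires for 0, 3 and 5 only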
