support mc2 for mp lora. #8161

Merged · 1 commit · Mar 21, 2024
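This PR routes the model-parallel LoRA base matmul and all-reduce through MC2 fused kernels (tensor-parallel compute/communication overlap) on NPU. The fused path is opt-in: it is taken only when Paddle reports an `npu` custom device and the `MC2` environment variable is set to a non-zero value. A minimal sketch of the gate, mirroring the condition added in `lora_layers.py` (illustrative only, not part of the diff):

```python
import os

import paddle

# Same condition as the branches added in this PR: requires an NPU custom
# device build of Paddle AND `export MC2=1` (or any non-zero value) at launch.
use_mc2 = "npu" in paddle.device.get_all_custom_device_type() and bool(int(os.getenv("MC2", "0")))
print("MC2 fused matmul+all-reduce path enabled:", use_mc2)
```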
40 changes: 29 additions & 11 deletions paddlenlp/peft/lora/lora_layers.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+import os
 from typing import List, Optional
 
 import paddle
@@ -24,6 +25,12 @@
     RowParallelLinear,
 )
 
+if "npu" in paddle.device.get_all_custom_device_type():
+    from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
+else:
+    MC2LoRaRowParallelLinear = None
+    MC2LoRaColumnParallelLinear = None
+
 
 class LoRALinear(nn.Linear):
     # LoRA implemented in a dense layer
@@ -188,14 +195,17 @@
         input_mp = x
 
         # x @ W : [bz, in_f / ws] ===> [bz, out_f]
-        result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
-
-        output = mp_ops._mp_allreduce(
-            result_mp,
-            group=self.model_parallel_group,
-            use_calc_stream=True,
-            use_model_parallel=True,
-        )
+        if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
+            output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
+        else:
+            result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
+
+            output = mp_ops._mp_allreduce(
+                result_mp,
+                group=self.model_parallel_group,
+                use_calc_stream=True,
+                use_model_parallel=True,
+            )
Review comment (Contributor): Lines 212 to 219 also contain a matmul followed by an all-reduce.

Reply (Contributor, Author): This was measured; adding the fused op there brings no performance gain and in fact slows things down.


         if not self.merged:
             # x @ A: [bz, in_f/ ws] ===> [bz, r]
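The review exchange above asks why the LoRA-delta matmul and all-reduce (lines 212 to 219) are not fused as well; per the author's measurement there was no gain there, so only the base `x @ W` path changes. For readers unfamiliar with what the fused call replaces, here is a hedged reference sketch of the row-parallel base computation, using the public `paddle.distributed.all_reduce` as a stand-in for the internal `mp_ops._mp_allreduce`. The helper below is illustrative and assumes an already initialized model-parallel group:

```python
import paddle
import paddle.distributed as dist
import paddle.nn.functional as F


def row_parallel_base_unfused(input_mp, weight_shard, mp_group):
    """Default branch: local matmul against the [in_f/ws, out_f] weight shard,
    then an explicit sum all-reduce of the partial result over the
    model-parallel group."""
    result_mp = F.linear(x=input_mp, weight=weight_shard)
    dist.all_reduce(result_mp, op=dist.ReduceOp.SUM, group=mp_group)
    return result_mp


# The MC2 branch computes the same output in one call,
#   output = MC2LoRaRowParallelLinear.apply(input_mp, weight_shard, mp_group)
# which issues a single fused matmul + all-reduce kernel so the communication
# overlaps with the compute.
```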
@@ -294,13 +304,21 @@
         self.merged = True
 
     def forward(self, input: paddle.Tensor):
-        input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
-        result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
+        if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
+            res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
+            result_mp = res_mp + self.bias
+        else:
+            input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
+            result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
 
         if not self.merged:
             input_a = self.lora_dropout(input) @ self.lora_A
-            input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
-            delta_mp = (input_a_mp @ self.lora_B) * self.scaling
+            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
+                tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                delta_mp = tmp * self.scaling
+            else:
+                input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
+                delta_mp = (input_a_mp @ self.lora_B) * self.scaling
             result_mp += delta_mp
 
         if self.gather_output and self.is_mp:
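In the column-parallel case there is no collective in the forward pass: `mp_ops._c_identity` is an identity in forward and only turns into an all-reduce of the incoming gradient in backward. That is why, with MC2 enabled, the forward above becomes a plain local matmul plus a bias add, and the fused matmul + all-reduce only appears in the backward of the new PyLayer in `mc2_lora_npu.py` below. A hedged reference sketch of that backward, again with `paddle.distributed.all_reduce` standing in for the fused kernel (illustrative helper, assumes an initialized model-parallel group, not part of the PR):

```python
import paddle
import paddle.distributed as dist


def column_parallel_backward_reference(dy, x, weight_shard, mp_group):
    """Unfused reference for MC2LoRaColumnParallelLinear.backward: the input
    gradient dy @ W^T must be summed across the model-parallel group; the MC2
    kernel fuses that all-reduce with the matmul on NPU."""
    dy_2d = dy.reshape([-1, dy.shape[-1]])
    d_weight = x.reshape([-1, x.shape[-1]]).t().matmul(dy_2d)
    d_input = paddle.matmul(dy_2d, weight_shard.t())
    dist.all_reduce(d_input, op=dist.ReduceOp.SUM, group=mp_group)
    return d_input.reshape(x.shape), d_weight
```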
76 changes: 76 additions & 0 deletions paddlenlp/peft/lora/mc2_lora_npu.py
@@ -0,0 +1,76 @@
# !/usr/bin/env python3

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" mc2(tp overlap) """

import paddle
import paddle_custom_device
from paddle.autograd import PyLayer


class MC2LoRaRowParallelLinear(PyLayer):
    @staticmethod
    def forward(ctx, input_, weight, group):
        ctx.save_for_backward(input_, weight)
        rank = paddle.distributed.get_rank()
        hcom_name = group.process_group.get_comm_name(rank)
        # Flatten to 2-D and run the fused matmul + all-reduce over the
        # model-parallel communicator in a single NPU kernel.
        x = input_.reshape([-1, input_.shape[-1]])
        out = paddle_custom_device.npu.fused_mm_allreduce(
            x, weight, bias=None, hcom=hcom_name, reduce_op="sum", comm_turn=0
        )
        output = out.reshape([input_.shape[0], input_.shape[1], weight.shape[1]])
        ctx.ring_id = group.id
        return output

    @staticmethod
    def backward(ctx, dy):
        input_, weight = ctx.saved_tensor()
        out_grad = dy
        sub_grad = out_grad.reshape([-1, out_grad.shape[-1]])
        # The input gradient is a purely local matmul; no collective is needed
        # in the row-parallel backward.
        input_grad = paddle.matmul(sub_grad, weight.t())
        if weight.stop_gradient:
            return input_grad.reshape(input_.shape)
        else:
            input_reshape = input_.reshape([-1, input_.shape[-1]])
            weight_grad = input_reshape.t().matmul(sub_grad)
            return input_grad.reshape(input_.shape), weight_grad
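A quick shape sketch of what the forward above does, with a plain `paddle.matmul` standing in for the fused kernel (single process, illustrative sizes, all-reduce omitted; not part of the file):

```python
import paddle

# The PyLayer flattens [bz, seq, in_f/ws] to 2-D, multiplies by the
# [in_f/ws, out_f] weight shard (fused with the all-reduce on NPU),
# then restores the leading dims, matching the reshapes in forward() above.
bz, seq, in_f_shard, out_f = 2, 8, 64, 128  # illustrative sizes
x = paddle.randn([bz, seq, in_f_shard])
w = paddle.randn([in_f_shard, out_f])
out = paddle.matmul(x.reshape([-1, in_f_shard]), w).reshape([bz, seq, out_f])
print(out.shape)  # [2, 8, 128]
```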


class MC2LoRaColumnParallelLinear(PyLayer):
    @staticmethod
    def forward(ctx, input_, weight, group):
        ctx.save_for_backward(input_, weight)
        ctx.group = group
        # Column parallel needs no collective in forward: just a local matmul
        # against the [in_f, out_f / ws] weight shard.
        input_mp = input_
        result_mp = paddle.matmul(input_mp, weight)
        return result_mp

    @staticmethod
    def backward(ctx, dy):
        input_, weight = ctx.saved_tensor()
        sub_grad = dy.reshape([-1, dy.shape[-1]])
        rank = paddle.distributed.get_rank()
        hcom_name = ctx.group.process_group.get_comm_name(rank)

        d_weight = input_.reshape([-1, input_.shape[-1]]).t().matmul(sub_grad) if not weight.stop_gradient else None
        # The input gradient must be summed across the model-parallel group;
        # the fused kernel overlaps that all-reduce with the dy @ W^T matmul.
        d_input = paddle_custom_device.npu.fused_mm_allreduce(
            sub_grad, weight.t(), bias=None, hcom=hcom_name, reduce_op="sum", comm_turn=0
        )

        if d_weight is not None:
            return d_input.reshape(input_.shape), d_weight
        else:
            return d_input.reshape(input_.shape)
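One more note on the `stop_gradient` branching in both backward methods: in LoRA the base weight is typically frozen, so the PyLayer returns only the input gradient in that case. A minimal, self-contained illustration of this Paddle `PyLayer` pattern (not part of the PR, names are hypothetical):

```python
import paddle
from paddle.autograd import PyLayer


class ScaleByFrozenWeight(PyLayer):
    """Mirrors the pattern above: when the weight is frozen (the usual LoRA
    case, base weights have stop_gradient=True), backward returns only the
    input gradient; otherwise it returns both gradients."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, dy):
        x, w = ctx.saved_tensor()
        dx = dy * w
        if w.stop_gradient:
            return dx
        return dx, dy * x


x = paddle.ones([3])
x.stop_gradient = False
w = paddle.full([3], 2.0)
w.stop_gradient = True  # frozen, like the LoRA base weight
y = ScaleByFrozenWeight.apply(x, w)
y.sum().backward()
print(x.grad)  # Tensor([2., 2., 2.])
```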