Tensor Parallel (#153)
zkh2016 authored Aug 29, 2023
1 parent abc7b90 commit df43d6d
Showing 24 changed files with 804 additions and 191 deletions.
44 changes: 29 additions & 15 deletions bmtrain/block_layer.py
@@ -1,6 +1,6 @@
from typing import Dict, Iterable, Iterator, Union, List

from .utils import round_up
from .utils import (round_up, tp_split_tensor)
from .global_var import config
import torch
from . import nccl
@@ -94,7 +94,8 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev
"total": 0,
"storage_type": storage_type,
"requires_grad": param.requires_grad,
"group": param.group
"group": param.group,
"zero_comm" : param._zero_comm
}

param_shape = param._original_shape
@@ -108,11 +109,14 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev
offsets = {}
# initialize storage buffers
for kw, val in self._storage_info.items():
val["world_size"] = config["world_size"]
comm = val['zero_comm']
world_size = nccl.commCount(comm)
rank = nccl.commRank(comm)
val["world_size"] = world_size
partition_size = round_up(val["total"], val["world_size"]) // val["world_size"]
val["partition_size"] = partition_size
val["begin"] = config['rank'] * partition_size
val["end"] = (config['rank'] + 1) * partition_size
val["begin"] = rank * partition_size
val["end"] = (rank+1) * partition_size
offsets[kw] = 0
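A minimal sketch of the per-rank partition arithmetic above, assuming round_up rounds its first argument up to a multiple of its second (the helper's body is not shown in this diff); the bucket size and rank count in the comment are illustrative only.

def round_up(x, d):
    # assumed behaviour of bmtrain.utils.round_up: smallest multiple of d >= x
    return (x + d - 1) // d * d

def partition_bounds(total, rank, world_size):
    # mirrors the loop above: every rank of the bucket's zero communicator
    # owns one equally sized slice of the (padded) storage
    partition_size = round_up(total, world_size) // world_size
    return rank * partition_size, (rank + 1) * partition_size

# a bucket of 10 elements on a 4-rank communicator: round_up(10, 4) == 12,
# partition_size == 3, so rank 0 owns [0, 3), rank 1 [3, 6), rank 2 [6, 9),
# and rank 3 [9, 12), where the tail is padding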


@@ -302,13 +306,18 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
if key in state_dict:
# load here
input_param = state_dict[key]
param = it['parameter']
tp_mode = param._tp_mode
if input_param.__class__.__name__ == "DistributedTensorWrapper":
input_param = input_param.broadcast()
if input_param.shape != it["shape"]:

verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape)
if input_param.shape != verify_shape:
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.format(key, input_param.shape, it["shape"]))
.format(key, input_param.shape, verify_shape))
continue

param_st = it["offset"]
param_end = it["offset"] + it["size"]
kw_name = it["kw_name"]
@@ -322,16 +331,22 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
continue

# copy to buffer
assert input_param.numel() == it["size"]
verify_size = verify_shape.numel()
assert input_param.numel() == verify_size

contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous()

tp_split_dim = param._tp_split_dim
if tp_mode and tp_split_dim >= 0:
contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim)

offset_st = max(storage_st - param_st, 0)
offset_end = min(storage_end - param_st, contiguous_param.numel())
assert offset_st < offset_end

to_offset_st = offset_st + param_st - storage_st
to_offset_end = offset_end + param_st - storage_st

# copy to buffer
# PyTorch 1.11 changed the API of storage.__getitem__
d_dtype = self._storage_params[kw_name].dtype
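The helper tp_split_tensor used above comes from bmtrain/utils.py and its body is not part of this diff; judging from the call sites, a minimal equivalent would chunk the full checkpoint tensor along the split dimension and keep the slice owned by the local tensor-parallel rank, roughly:

import torch
from bmtrain.global_var import config

def tp_split_tensor(tensor, split_dim):
    # assumed behaviour, inferred from the call sites in this commit:
    # cut the full tensor into tp_size pieces along split_dim and return
    # the piece belonging to this process's tensor-parallel rank
    chunk = torch.chunk(tensor, config["tp_size"], dim=split_dim)[config["tp_rank"]]
    return chunk.contiguous()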
@@ -398,7 +413,7 @@ def init_parameters(self):
param = it["parameter"]
if isinstance(param, DistributedParameter) and param._init_method is not None:
# initialize here
tmp_tensor = torch.empty(it["shape"], device=param.device, dtype=param.dtype)
tmp_tensor = torch.empty(param._tp_original_shape, device=param.device, dtype=param.dtype)
param._init_method(tmp_tensor)
param_st = it["offset"]
param_end = it["offset"] + it["size"]
@@ -412,16 +427,15 @@ def init_parameters(self):
if param_end <= storage_st:
continue

if param._tp_mode and param._tp_split_dim >= 0:
tmp_tensor = tp_split_tensor(tmp_tensor, param._tp_split_dim)
# copy to buffer
assert tmp_tensor.is_contiguous() and it["size"] == tmp_tensor.numel()

offset_st = max(storage_st - param_st, 0)
offset_end = min(storage_end - param_st, tmp_tensor.numel())
offset_st = max(storage_st - param_st, 0)
offset_end = min(storage_end - param_st, tmp_tensor.numel())
assert offset_st < offset_end

to_offset_st = offset_st + param_st - storage_st
to_offset_end = offset_end + param_st - storage_st

# copy to buffer
# PyTorch 1.11 changed the API of storage.__getitem__
d_dtype = self._storage_params[kw_name].dtype
9 changes: 3 additions & 6 deletions bmtrain/checkpointing.py
@@ -39,10 +39,7 @@ def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = Fal
self._param_tensor = {}
self._grad_tensor = {}
self._need_release = False
if pipe:
self.comm = config["zero_comm"]
else:
self.comm = config["comm"]

def enter(self, flag=0, requires_grad=False):
"""
gather parameters
@@ -74,7 +71,7 @@ def enter(self, flag=0, requires_grad=False):
nccl.allGather(
self.block._storage_params[kw].storage(),
self._param_buffer[kw],
self.comm
val['zero_comm']
)
nccl.groupEnd()

@@ -144,7 +141,7 @@ def exit(self, flag=0, backward=False):
self._grad_buffer[kw],
local_param.grad.storage(),
"sum",
self.comm
val['zero_comm']
)
nccl.groupEnd()
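The two hunks above drop the single self.comm and pass each bucket's own val['zero_comm'] instead, so every storage bucket is gathered and reduce-scattered over the communicator it was partitioned with. A hedged paraphrase of the enter/exit pattern (the helper names and the buffer arguments below are placeholders, not part of the commit):

from bmtrain import nccl

def gather_params(block, param_buffer):
    # before the block runs: reassemble every bucket's full parameter storage
    # from the slices held by the ranks of that bucket's zero communicator
    nccl.groupStart()
    for kw, val in block._storage_info.items():
        nccl.allGather(block._storage_params[kw].storage(),
                       param_buffer[kw],
                       val['zero_comm'])
    nccl.groupEnd()

def scatter_grads(block, grad_buffer, local_grad_storage):
    # after backward: sum the full-size gradient buffers and hand each rank
    # back only the slice matching its parameter partition
    nccl.groupStart()
    for kw, val in block._storage_info.items():
        nccl.reduceScatter(grad_buffer[kw],
                           local_grad_storage[kw],
                           "sum",
                           val['zero_comm'])
    nccl.groupEnd()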

84 changes: 61 additions & 23 deletions bmtrain/init.py
@@ -10,19 +10,22 @@
from . import nccl
from .synchronize import synchronize


def init_distributed(
init_method : str = "env://",
seed : int = 0,
pipe_size: int = -1,
num_micro_batches: int = None,
tp_size : int = 1,
):
"""Initialize distributed training.
This function will initialize the distributed training, set the random seed and global configurations.
It must be called before any other distributed functions.
Args:
seed (int): The random seed.
pipe_size (int) : all processes will be divided into pipeline-parallel groups of size pipe_size
num_micro_batches (int) : the input batch will be divided into num_micro_batches smaller batches; only used in pipeline mode.
tp_size (int) : the size of each tensor parallel group
**init_distributed** reads the following environment variables:
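A usage sketch for the new tp_size argument (the process counts are illustrative, not taken from this commit): launching 8 processes with pipe_size=2 and tp_size=2 leaves world_size // (pipe_size * tp_size) = 2 ranks in each tp-zero group.

import bmtrain as bmt

# 8 processes total: 2 pipeline stages x 2-way tensor parallel
# x 2-way ZeRO/data parallel inside each (stage, tp) pair
bmt.init_distributed(
    seed=0,
    pipe_size=2,
    tp_size=2,
)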
@@ -70,10 +73,15 @@ def init_distributed(
config["world_size"] = world_size
config["calc_stream"] = torch.cuda.current_stream()
config["load_stream"] = torch.cuda.Stream(priority=-1)
config["tp_comm_stream"] = torch.cuda.Stream(priority=-1)
config["pp_comm_stream"] = torch.cuda.Stream(priority=-1)
config['barrier_stream'] = torch.cuda.Stream()
config["load_event"] = torch.cuda.Event()
config["tp_size"] = tp_size if tp_size > 0 else 1
config["topology"] = topology(config)
config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank']
config["zero_rank"] = config['topology'].get_group_rank("zero")
config["tp_rank"] = config['topology'].get_group_rank("tp")
config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero")
cpus_this_worker = None

all_available_cpus = sorted(list(os.sched_getaffinity(0)))
@@ -102,21 +110,34 @@

unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode())
config['comm'] = nccl.commInitRank(unique_id, world_size, rank)
topo = config['topology']

if config['pipe_enabled']:
config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"]
topo = config['topology']
if topo.stage_id == 0:
unique_id = nccl.getUniqueId()
store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex())
unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode())
config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id)
if topo.zero_id == 0:
unique_id = nccl.getUniqueId()
store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() )
unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode())
config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//pipe_size, topo.zero_id)
else:
config['zero_comm'] = config['comm']

if topo.tp_id == 0:
unique_id = nccl.getUniqueId()
store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex())
unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode())
config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id)

if topo.tp_zero_id == 0:
unique_id = nccl.getUniqueId()
store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() )
unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode())
config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id)

if topo.zero_id == 0:
unique_id = nccl.getUniqueId()
store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() )
unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode())
config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size']), topo.zero_id)

for i in range(world_size):
if i == rank:
print_dict("Initialization", {
@@ -129,40 +150,57 @@
"cpus": cpus_this_worker
})
synchronize()

class topology:
def __init__(self,config):
# pipe_idx is the idx of the pipeline in the group
self.rank = config['rank']
pp_size = config["pipe_size"]
tp_size = config["tp_size"]
world_size = config["world_size"]
assert world_size % pp_size == 0, "The nums of GPUs must be divisible by the pipeline parallel size"

dp_size = world_size // pp_size
topo=torch.tensor(range(dp_size*pp_size),dtype=torch.int,device='cuda')
topo=topo.view(pp_size,dp_size)
self.pp_group=topo.transpose(0,1).reshape(-1,pp_size)
self.dp_group=topo
self.stage_id = (self.pp_group == self.rank).nonzero()[0,-1].item()
assert world_size % (pp_size * tp_size) == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size"

dp_size = world_size // (pp_size * tp_size)
config['tp_zero_size'] = dp_size
config['zero_size'] = world_size // pp_size
topo=torch.tensor(range(dp_size*tp_size*pp_size),dtype=torch.int,device='cuda')
topo=topo.view(pp_size,dp_size*tp_size)
self.stages = config['pipe_size']
self.pipe_idx = (self.pp_group == self.rank).nonzero()[0, 0].item() # x axes
self.zero_id = self.pipe_idx
self.zero_idx = self.stage_id

stage_size = world_size // pp_size
for i in range(world_size):
self.pipe_idx = self.rank % stage_size
self.stage_id = self.rank // stage_size
self.tp_id = self.rank % tp_size
self.tp_idx = self.rank // tp_size
self.zero_idx = self.stage_id
self.zero_id = self.pipe_idx
self.tp_zero_idx = self.stage_id * tp_size + self.tp_id
self.tp_zero_id = self.pipe_idx // tp_size

self.next_rank = self.stage_id+1 if self.stage_id < config['pipe_size'] - 1 else -1
self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1
self.tails = self.pp_group[self.pipe_idx, self.stage_id:].tolist()
self.heads = self.pp_group[self.pipe_idx, :self.stage_id + 1].tolist()


def get_group_id(self,group_name):
if group_name == "pipe":
return self.pipe_idx
elif group_name == "zero":
return self.zero_idx
elif group_name == "tp_zero":
return self.tp_zero_idx
elif group_name == "tp":
return self.tp_idx

def get_group_rank(self,group_name):
if group_name == "pipe":
return self.stage_id
elif group_name == "zero":
return self.zero_id
elif group_name == "tp_zero":
return self.tp_zero_id
elif group_name == "tp":
return self.tp_id

def is_initialized() -> bool:
return config["initialized"]
20 changes: 15 additions & 5 deletions bmtrain/layer.py
@@ -1,6 +1,8 @@
import torch
from .parameter import DistributedParameter
from .global_var import config
import itertools
from .utils import tp_split_tensor

class DistributedModule(torch.nn.Module):
"""
@@ -11,7 +13,7 @@ class DistributedModule(torch.nn.Module):
def __getattr__(self, name: str):
ret = super().__getattr__(name)
# gather distributed parameters if not in CheckpointBlock
if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block:
if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block:
return ret.gather()
return ret

@@ -30,8 +32,11 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
"""
for name, param in self._parameters.items():
if param is not None:
if isinstance(param, DistributedParameter) and not param._in_checkpoint_block:
destination[prefix + name] = param.gather().detach().cpu() # sync operation
if isinstance(param, DistributedParameter):#and not param._in_checkpoint_block:
if param._in_checkpoint_block:
destination[prefix + name] = param.tp_gather().detach().cpu() # sync operation
else:
destination[prefix + name] = param.gather_all().detach().cpu() # sync operation
else:
destination[prefix + name] = param if keep_vars else param.detach().cpu()
for name, buf in self._buffers.items():
@@ -81,6 +86,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
tp_mode = param._tp_mode
input_param = state_dict[key]
if input_param.__class__.__name__ == "DistributedTensorWrapper":
input_param = input_param.broadcast()
Expand All @@ -98,13 +104,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
'the shape in current model is {}.'
.format(key, input_param.shape, param.shape))
continue
if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != param._original_shape:
verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape)
if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != verify_shape:
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.format(key, input_param.shape, param.shape))
.format(key, input_param.shape, verify_shape))
try:
with torch.no_grad():
if isinstance(param, DistributedParameter):
tp_split_dim = param._tp_split_dim
if tp_mode and tp_split_dim >= 0:
input_param = tp_split_tensor(input_param, tp_split_dim)
param._copy_data(input_param)
else:
param.copy_(input_param)