Commit 10d192c

could apply loss function on inspector tensors
Achazwl committed Oct 5, 2022
1 parent 9e27b46 commit 10d192c
Showing 7 changed files with 465 additions and 240 deletions.
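In outline, the change threads the tensors recorded by ScopedTensorInspectorContext through the autograd Function itself: forward() appends them to the returned tuple, and backward() receives their gradients as extra arguments and passes them to torch.autograd.backward() during recomputation, so a loss applied to an inspected tensor propagates into the checkpointed block. A minimal, self-contained sketch of that pattern in plain PyTorch (illustrative names, not BMTrain's actual classes):

```python
import torch

class CheckpointWithInspect(torch.autograd.Function):
    """Checkpoint-style Function that also exposes an inspected intermediate tensor."""

    @staticmethod
    def forward(ctx, func, x):
        ctx.func = func
        ctx.save_for_backward(x)
        with torch.no_grad():
            out, inspected = func(x)        # forward pass without building a graph
        # Returning `inspected` as an extra output means its gradient arrives
        # in backward() as an extra argument.
        return out, inspected

    @staticmethod
    def backward(ctx, grad_out, grad_inspected):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            out, inspected = ctx.func(x)    # recompute with grad enabled
        # Back-propagate through both the real output and the inspected tensor.
        torch.autograd.backward([out, inspected], [grad_out, grad_inspected])
        return None, x.grad


def block(x):
    h = torch.tanh(x)          # the "inspected" hidden state
    return h.sum(), h


x = torch.randn(4, requires_grad=True)
out, hidden = CheckpointWithInspect.apply(block, x)
loss = out + hidden.pow(2).mean()   # loss term on the inspected tensor
loss.backward()
print(x.grad)                       # gradients flow through both paths
```
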
1 change: 0 additions & 1 deletion bmtrain/__init__.py
@@ -6,7 +6,6 @@
from .param_init import init_parameters, grouped_parameters
from .utils import print_block, print_dict, print_rank, see_memory
from .synchronize import synchronize, sum_loss, wait_loader, gather_result
from .checkpointing import checkpoint
from .block_layer import CheckpointBlock, TransformerBlockList
from .wrapper import BMTrainModelWrapper
from .pipe_layer import PipelineTransformerBlockList
57 changes: 36 additions & 21 deletions bmtrain/block_layer.py
@@ -46,10 +46,13 @@ def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len
for it in inspector.hidden_states:
debug.append("_inspect_hidden_states", it)
ctx.inspect_list = inspector.hidden_states
return outputs

if not isinstance(outputs, list) and not isinstance(outputs, tuple):
outputs = [outputs]
return tuple([len(outputs)] + outputs + [hidden_state["tensor"] for hidden_state in inspector.hidden_states])

@staticmethod
def backward(ctx, *grad_outputs):
def backward(ctx, _, *grads):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
@@ -77,28 +80,29 @@ def backward(ctx, *grad_outputs):
flag = 2
else:
flag = 0
with torch.enable_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(ctx.block, ctx.param_dict, flag):
with torch.enable_grad(), CheckpointBlockContext(ctx.block, ctx.param_dict, flag):
inp_args = all_inputs[:len_args]
inp_kwargs = {}
for k, v in zip(all_inputs[len_args::2], all_inputs[len_args + 1::2]):
inp_kwargs[k] = v
outputs = ctx.block._module._call_impl(*inp_args, **inp_kwargs)
with ScopedTensorInspectorContext() as inspector:
outputs = ctx.block._module._call_impl(*inp_args, **inp_kwargs)
if not isinstance(outputs, tuple):
outputs = (outputs,)

assert len(outputs) == len(grad_outputs)
assert len(outputs) + len(inspector.hidden_states) == len(grads)

outputs_with_grad = []
grad_of_output = []
for i, output in enumerate(outputs):
if torch.is_tensor(output) and output.requires_grad:
outputs_with_grad.append(output)
grad_of_output.append(grad_outputs[i])
grad_of_output.append(grads[i])

# calculate gradients for inputs, also for parameters
torch.autograd.backward(
outputs_with_grad,
grad_of_output,
outputs_with_grad + [hidden_state["tensor"] for hidden_state in inspector.hidden_states],
grad_of_output + list(grads[len(outputs):]),
)
assert len(ctx.inspect_list) == len(inspector.hidden_states), "Backward step changed"
for i, it in enumerate(inspector.hidden_states):
@@ -108,6 +112,7 @@ def backward(ctx, *grad_outputs):

# change the tensor in placeholder
ctx.inspect_list[i]["tensor"] = it["tensor"]
ctx.inspect_list[i]["requires_grad"] = it["requires_grad"]

grads = []
for inp, requires_grad in zip(all_inputs, input_reqires_grad):
@@ -442,19 +447,26 @@ def __call__(self, *args, **kwargs):
for kw, val in kwargs.items():
all_inputs.append(kw)
all_inputs.append(val)
return OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs)
outputs = OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs)
len_output = outputs[0]
return outputs[1:1+len_output] if len_output > 1 else outputs[1]

def __getattr__(self,name:str):
if name=="_module":
return self._module
return getattr(self._module, name)

def __setattr__(self, name, value):
object.__setattr__(self, name, value)

def __getattribute__(self, name: str):
if name=="_parameters":
return self._module._parameters
return super().__getattribute__(name)

def __delattr__(self, name):
object.__delattr__(self, name)

def _save_to_state_dict(self, destination, prefix, keep_vars):
raise RuntimeError("._save_to_state_dict() of CheckpointBlock should not be called")

@@ -614,6 +626,7 @@ def named_modules(self, memo = None, prefix: str = '', remove_duplicate: bool =
submodule_prefix = prefix + ('.' if prefix else '') + name
for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
yield m

def named_children(self):
return self._module.named_children()

@@ -647,8 +660,8 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s
layer_inputs = []
layer_inspector = []
cuda_rng_state = []
with torch.no_grad():
for i in range(len(self)):
for i in range(len(self)):
with torch.no_grad():
if save_list[i][0] == i:
layer_inputs.append(hidden_state.detach())
cuda_rng_state.append( torch.cuda.get_rng_state() )
@@ -662,10 +675,10 @@
# call inner module directly
with ScopedTensorInspectorContext() as inspector:
hidden_state = self._modules[str(i)]._module._call_impl(hidden_state, *args)
for it in inspector.hidden_states:
debug.append("_inspect_hidden_states", it)
layer_inspector.append(inspector.hidden_states)
block_ctx.exit()
for it in inspector.hidden_states:
debug.append("_inspect_hidden_states", it)
layer_inspector.append(inspector.hidden_states)

ctx.layer_inspector = layer_inspector
ctx.cuda_rng_state = cuda_rng_state
@@ -677,13 +690,13 @@
for mid in middle_hiddens:
mid.requires_grad_()
middle_hiddens = torch.stack(middle_hiddens, dim=0)
return hidden_state, middle_hiddens
else:
return hidden_state, None
middle_hiddens = None
return tuple([hidden_state, middle_hiddens] + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens])


@staticmethod
def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle: List[torch.Tensor]):
def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle: List[torch.Tensor], *grad_inspectors):
def exit_prev(prev_ctx, prev_grad):
if prev_ctx is not None:
if prev_grad:
@@ -762,12 +775,13 @@ def exit_prev(prev_ctx, prev_grad):
assert it["group"] == ctx.layer_inspector[i][j]["group"], "Backward step changed"

# change the tensor in placeholder
ctx.layer_inspector[i][j]["requires_grad"] = it["requires_grad"]
ctx.layer_inspector[i][j]["tensor"] = it["tensor"]
ctx.layer_inspector[i][j]["requires_grad"] = it["requires_grad"]
torch.autograd.backward(
[output],
[grad_hidden_state]
[output] + [hidden_state["tensor"] for hidden_state in inspector.hidden_states],
(grad_hidden_state,) + grad_inspectors[-len(inspector.hidden_states):],
)
grad_inspectors = grad_inspectors[:-len(inspector.hidden_states)]
grad_hidden_state = ipt.grad
if grad_middle is not None:
grad_hidden_state = grad_hidden_state + grad_middle[i]
@@ -845,7 +859,8 @@ def __getitem__(self, index: Union[int, str]) -> CheckpointBlock:
def forward(self, hidden_state, *args, return_hidden_states = False):
self.return_hidden_states = return_hidden_states
placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
last_hidden, middle_hiddens = OpTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args)
outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args)
last_hidden, middle_hiddens = outputs[:2]
if return_hidden_states:
return last_hidden, middle_hiddens
else:
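At the call site, the effect of the block_layer.py changes is that hidden states recorded inside a CheckpointBlock or TransformerBlockList can now carry an auxiliary loss. A hypothetical usage sketch follows; bmt.inspect.inspect_tensor, record_tensor, and the entry layout are assumptions modeled on the inspector.hidden_states dicts in the diff above, not a confirmed BMTrain API.

```python
# Hypothetical usage sketch; the bmt.inspect.* names and the dict layout
# ({"name", "tensor", ...}) are assumptions, not BMTrain's documented API.
import bmtrain as bmt

def step(model, loss_func, inputs, targets):
    with bmt.inspect.inspect_tensor() as inspector:
        # the model is assumed to call something like
        # bmt.inspect.record_tensor(h, "h") inside its CheckpointBlocks
        logits = model(inputs)
        task_loss = loss_func(logits, targets)
        # Before this commit the recorded tensors were detached from the outer
        # graph; now their gradients reach the recomputation in backward().
        aux_loss = sum(h["tensor"].float().pow(2).mean()
                       for h in inspector.hidden_states)
    loss = task_loss + 0.01 * aux_loss
    loss.backward()
    return loss
```
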
98 changes: 0 additions & 98 deletions bmtrain/checkpointing.py
@@ -28,101 +28,3 @@ def __exit__(self, *args):
self._local_list._set_hidden_states(debug.get("_inspect_hidden_states", []))
debug.set("_inspect_hidden_states", self.prev_hidden)
self.prev_hidden = None

class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, placeholder, func, preserve_rng_state, *args):
ctx.func = func
ctx.preserve_rng_state = preserve_rng_state

ctx.cuda_rng_state = torch.cuda.get_rng_state() if preserve_rng_state else None

tensors = []
others = []
for arg in args:
if torch.is_tensor(arg):
tensors.append(arg)
others.append(None)
else:
tensors.append(None)
others.append(arg)
ctx.nontensor_inputs = others
ctx.save_for_backward(*tensors)

with torch.no_grad(), ScopedTensorInspectorContext() as inspector:
outputs = func(*args)

# append scoped hidden states to global list as a placeholder
for it in inspector.hidden_states:
debug.append("_inspect_hidden_states", it)
ctx.inspect_list = inspector.hidden_states

return outputs

@staticmethod
def backward(ctx, *grad_outputs):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
" argument.")

all_inputs = []
input_reqires_grad = []
for tensor, other in zip(ctx.saved_tensors, ctx.nontensor_inputs):
if tensor is None:
all_inputs.append(other)
input_reqires_grad.append(False)
else:
input_reqires_grad.append( tensor.requires_grad )
nw_tensor = tensor.detach()
nw_tensor.requires_grad = tensor.requires_grad
all_inputs.append(nw_tensor)


with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=ctx.preserve_rng_state):
if ctx.preserve_rng_state:
torch.cuda.set_rng_state(ctx.cuda_rng_state)
with torch.enable_grad(), ScopedTensorInspectorContext() as inspector:
outputs = ctx.func(*all_inputs)

assert len(ctx.inspect_list) == len(inspector.hidden_states), "Backward step changed"
for i, it in enumerate(inspector.hidden_states):
assert it["name"] == ctx.inspect_list[i]["name"], "Backward step changed"
assert it["shape"] == ctx.inspect_list[i]["shape"], "Backward step changed"
assert it["group"] == ctx.inspect_list[i]["group"], "Backward step changed"

# change the tensor in placeholder
ctx.inspect_list[i]["tensor"] = it["tensor"]
if not isinstance(outputs, tuple):
outputs = (outputs,)

assert len(outputs) == len(grad_outputs)

outputs_with_grad = []
grad_of_output = []
for i, output in enumerate(outputs):
if torch.is_tensor(output) and output.requires_grad:
outputs_with_grad.append(output)
grad_of_output.append(grad_outputs[i])

torch.autograd.backward(
outputs_with_grad,
grad_of_output,
)
grads = []
for inp, requires_grad in zip(all_inputs, input_reqires_grad):
if requires_grad:
grads.append(inp.grad)
else:
grads.append(None)
return (None, None, None) + tuple(grads)


R = TypeVar("R")
def checkpoint(func : Callable[..., R]) -> Callable[..., R]:
@wraps(func)
def wrapper(*args):
placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
return CheckpointFunction.apply(placeholder, func, True, *args)
return wrapper
(Diffs for the remaining 4 changed files are not shown.)
