[Schedule] Create subschedule for subgraph replacement #52

Merged · 10 commits · Feb 13, 2023
2 changes: 1 addition & 1 deletion examples/bert/schedule.py
@@ -23,7 +23,7 @@ def new_expand(tensor, *args):
         return out.contiguous()
 
     for op in ops:
-        sch.replace(new_expand, op[1])
+        sch.replace(new_expand, op)
 
 
 def replace_and_shard_attention(
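The same one-line change recurs in the gpt, gpt2, opt, and roberta schedules below: `sch.replace` now receives the whole (parent_module_name, node) tuple rather than the bare torch.fx node that `op[1]` used to extract (see the updated `replace_function` docstring further down). A minimal sketch of the new calling convention; `ops` is assumed to hold such tuples, and the body of `new_expand` above its return is an assumption:

    def new_expand(tensor, *args):
        out = tensor.expand(*args)  # assumed body; the hunk only shows the return
        return out.contiguous()

    for op in ops:
        # before this PR: sch.replace(new_expand, op[1])  (bare fx node)
        sch.replace(new_expand, op)  # now: the whole (name, node) tuple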
2 changes: 1 addition & 1 deletion examples/gpt/schedule.py
@@ -42,7 +42,7 @@ def new_expand(tensor, *args):
         return out.contiguous()
 
     for op in ops:
-        sch.replace(new_expand, op[1])
+        sch.replace(new_expand, op)
 
 
 def replace_and_shard_attention(
2 changes: 1 addition & 1 deletion examples/gpt2/schedule.py
@@ -213,7 +213,7 @@ def new_expand(tensor, *args):
         return out.contiguous()
 
     for op in ops:
-        sch.replace(new_expand, op[1])
+        sch.replace(new_expand, op)
 
 
 class AttentionOpWithRNG(FlashAttentionOp):
4 changes: 2 additions & 2 deletions examples/opt/schedule.py
@@ -42,7 +42,7 @@ def new_expand(tensor, *args):
         return out.contiguous()
 
     for op in ops:
-        sch.replace(new_expand, op[1])
+        sch.replace(new_expand, op)
 
 
 def replace_and_shard_attention(
@@ -152,7 +152,7 @@ def remove_cast(sch, config, attn_path="h.N.attn.attention"):
         )
 
         for op in ops:
-            sub_sch.replace(lambda x, *args: x, op[1])
+            sub_sch.replace(lambda x, *args: x, op)
             cnt += 1
     return cnt
 
2 changes: 1 addition & 1 deletion examples/roberta/schedule.py
@@ -41,7 +41,7 @@ def new_expand(tensor, *args):
         return out.contiguous()
 
     for op in ops:
-        sch.replace(new_expand, op[1])
+        sch.replace(new_expand, op)
 
 
 def replace_and_shard_attention(
55 changes: 38 additions & 17 deletions slapo/schedule.py
@@ -39,7 +39,11 @@
 from .tracer import trace as trace_module
 from .initialization import init_empty_weights
 from .op.linear import LinearWithSeparateBias
-from .utils.common import transfer_hooks, is_lambda_function
+from .utils.common import (
+    transfer_hooks,
+    transfer_hooks_for_fusion,
+    is_lambda_function,
+)
 from .utils.mapping import MAPPING_FROM_FUNCTIONAL_TO_MODULE
 from .pattern import Pattern, ModulePattern, call_module
 
@@ -744,10 +748,12 @@ def replace_function(self, func, target_op):
         ----------
         func : Callable
             The new function to replace the current function.
-        target_op : torch.fx.Node
+        target_op : List[Tuple[str, torch.fx.Node]]
             The call_function node to be replaced.
+            The string in the tuple is the name of the parent module that
+            the node belongs to.
         """
-        node = target_op
+        _, node = target_op
         with self.mod.graph.inserting_after(node):
             new_node = self.mod.graph.call_function(func, node.args, node.kwargs)
             node.replace_all_uses_with(new_node)
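Only the core of `replace_function` is visible in the hunk; the splice it performs is standard torch.fx graph surgery. A self-contained sketch of the same rewiring outside slapo (the toy module and the trailing `erase_node`/`recompile` steps are assumptions, since those lines are elided from the hunk):

    import torch
    import torch.fx as fx

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.sigmoid(x) + 1

    gm = fx.symbolic_trace(Toy())
    # Build (parent_module_name, node) tuples; "" stands for the root module.
    ops = [
        ("", node)
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target == torch.sigmoid
    ]
    for _, node in ops:  # the same unpacking as `_, node = target_op`
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(torch.tanh, node.args, node.kwargs)
            node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)
    gm.recompile()  # the module now computes tanh(x) + 1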
@@ -757,6 +763,9 @@ def replace_module(self, new_mod, subgraphs=None, name=None):
         """Replace an entire module with a new one.
         Do NOT directly call this function, use `.replace()` instead.
 
+        If subgraphs is None, replace the whole self module;
+        Otherwise, replace target forward subgraphs with the new module.
+
         Parameters
         ----------
         new_mod : torch.nn.Module
@@ -769,28 +778,36 @@ def replace_module(self, new_mod, subgraphs=None, name=None):
             will be automatically generated.
         """
         if subgraphs is None:
-            # If subgraphs is None, replace the whole self module and the schedule.
-            new_sch = create_schedule(
-                new_mod, self.name, self.path, self.parent, self.group
-            )
-
+            if name is not None and name != self.name:
+                logger.warning(
+                    "Cannot change the name of %s when replacing the whole module. "
+                    "The given name %s will be ignored",
+                    self.name,
+                    name,
+                )
+            name = self.name
+        else:
+            if name is None:
+                name = new_mod._get_name().split(".")[-1]
+            name = _get_unique_module_name(self.mod, name)
+        # Create a new schedule for the replaced module
+        new_sch = create_schedule(new_mod, name, self.path, self.parent, self.group)
+        # Replace the corresponding part in the current module
+        if subgraphs is None:
+            # If subgraphs is None, replace the whole self module.
             # Transfer hooks from the old module to the new module.
             transfer_hooks(self.mod, new_sch.mod)
 
             # Update schedules
             self.mod = new_sch.mod
             self.child = new_sch.child
             for _, sch in self.child.items():
                 sch.parent = self
             if self.parent:
                 self.update_submodule(self.parent.mod, self.name, new_mod)
         else:
-            # Otherwise, replace target forward subgraphs with the new module.
-            # Note that this requires the current module in torch.fx so it
-            # has to be traced.
+            # Replacing target forward subgraphs with the new module
+            # requires the current module in torch.fx so it has to be traced.
            self.trace()
-            if name is None:
-                name = new_mod._get_name().split(".")[-1]
-            name = _get_unique_module_name(self.mod, name)
             assert len(subgraphs) > 0, "Should have at least one operator to replace"
             if len(subgraphs) > 1:
                 # horizontal fusion, e.g.,
@@ -825,6 +842,7 @@ def replace_module(self, new_mod, subgraphs=None, name=None):
                         if node.users not in sublst:
                             node.replace_all_uses_with(getitem)
                         target_mod.graph.erase_node(node)
+            transfer_hooks_for_fusion(self, subgraphs, new_mod)
         else:
             # vertical fusion, e.g.,
             # s0->v0
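The partly elided horizontal-fusion branch around this hunk routes the fused module's tuple outputs back to the original consumers through `operator.getitem` nodes, which is what the `getitem` in the context lines refers to. A standalone torch.fx sketch of that rewiring, independent of slapo (the toy modules are invented for illustration):

    import operator
    import torch
    import torch.fx as fx

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.q = torch.nn.Linear(8, 8)
            self.k = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.q(x) + self.k(x)

    class FusedQK(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.q = torch.nn.Linear(8, 8)
            self.k = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.q(x), self.k(x)  # tuple output, one slot per fused op

    gm = fx.symbolic_trace(Model())
    gm.add_submodule("fused_qk", FusedQK())
    targets = [n for n in gm.graph.nodes if n.op == "call_module"]
    with gm.graph.inserting_before(targets[0]):
        fused = gm.graph.call_module("fused_qk", targets[0].args)
    for idx, node in enumerate(targets):
        with gm.graph.inserting_after(fused):
            getitem = gm.graph.call_function(operator.getitem, (fused, idx))
        node.replace_all_uses_with(getitem)  # old consumers now read slot idx
        gm.graph.erase_node(node)
    gm.recompile()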
@@ -863,6 +881,9 @@ def replace_module(self, new_mod, subgraphs=None, name=None):
                 last_node.replace_all_uses_with(new_node)
                 for node in reversed(ops):
                     target_mod.graph.erase_node(node)
+                transfer_hooks_for_fusion(self, subgraphs, new_mod)
+            # Update schedules
+            self.child[name] = new_sch
 
     @register_primitive()
     def replace(self, new_mod_or_func, target_ops=None, name=None):
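The last two added lines are the headline of this PR: the schedule created for the replacement module is registered as a child of the current schedule, so a module produced by subgraph replacement can itself keep being scheduled. A hedged usage sketch; `qkv_pattern`, `FusedQKV`, and the shard call are stand-ins, and the exact child key depends on `_get_unique_module_name`:

    subgraphs = sch.find(qkv_pattern)  # matched forward subgraph(s)
    sch.replace(FusedQKV(hidden_size, num_heads), subgraphs, name="FusedQKV")

    # New in this PR: the replacement owns a subschedule, reachable by name,
    # so scheduling can continue on the fused module.
    sub_sch = sch["FusedQKV_0"]
    sub_sch.shard("fused_linear.weight", axis=0)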
@@ -871,9 +892,9 @@ def replace(self, new_mod_or_func, target_ops=None, name=None):
         2. Replace a part of the forward function (target_ops) with a new module or function.
         """
         if isinstance(new_mod_or_func, FunctionType):
-            if target_ops is None and isinstance(target_ops, list):
+            if isinstance(target_ops, list):
                 raise ValueError(
-                    "Cannot replace multiple subgraphs in forward with one function"
+                    "Cannot replace multiple nodes in forward with one function"
                 )
             self.replace_function(new_mod_or_func, target_ops)
         else:
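One detail worth calling out in this hunk: the old guard `target_ops is None and isinstance(target_ops, list)` could never be true, so handing a list of target ops to a plain function silently slipped through to `replace_function`; the rewritten guard raises as intended. A behavior sketch, assuming a schedule `sch` and (parent_name, node) tuples `op_a` and `op_b`:

    def identity(x, *args):
        return x

    try:
        sch.replace(identity, [op_a, op_b])  # function + list now raises
    except ValueError as err:
        print(err)  # Cannot replace multiple nodes in forward with one function

    sch.replace(identity, op_a)  # a single tuple dispatches to replace_function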
69 changes: 64 additions & 5 deletions slapo/utils/common.py
@@ -6,6 +6,13 @@
 from types import FunctionType
 
 
+HOOK_TYPE_TO_ATTR = {
+    "fwd_pre": "_forward_pre_hooks",
+    "fwd_post": "_forward_hooks",
+    "bwd_post": "_backward_hooks",
+}
+
+
 def get_hooks(mod):
     """Get the hooks of a module.
 
@@ -32,6 +39,10 @@ def get_hooks(mod):
     return hooks
 
 
+def has_hook(mod, hook_type):
+    return len(getattr(mod, HOOK_TYPE_TO_ATTR[hook_type])) > 0
+
+
 def transfer_hooks(old_mod, new_mod, hook_types=None):
     """Transfer the hooks from old_mod to new_mod.
 
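For readers unfamiliar with the underscore attributes in `HOOK_TYPE_TO_ATTR`: PyTorch keeps registered hooks in per-module ordered dicts, which is exactly what `has_hook` inspects. A self-contained check in plain PyTorch (not part of the diff):

    import torch

    lin = torch.nn.Linear(4, 4)
    print(len(lin._forward_pre_hooks))  # 0, so has_hook(lin, "fwd_pre") is False
    lin.register_forward_pre_hook(lambda mod, inputs: None)
    print(len(lin._forward_pre_hooks))  # 1, so has_hook(lin, "fwd_pre") is True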
@@ -44,18 +55,66 @@ def transfer_hooks(old_mod, new_mod, hook_types=None):
     hook_types : Optional[List[str]]
         The types of hooks to transfer. If None, transfer all hooks.
     """
-    HOOK_TYPE_TO_ATTR = {
-        "fwd_pre": "_forward_pre_hooks",
-        "fwd_post": "_forward_hooks",
-        "bwd_post": "_backward_hooks",
-    }
     if hook_types is None:
         hook_types = ["fwd_pre", "fwd_post", "bwd_post"]
 
     for hook_attr in [HOOK_TYPE_TO_ATTR[hook_type] for hook_type in hook_types]:
         setattr(new_mod, hook_attr, getattr(old_mod, hook_attr))
 
 
+def transfer_hooks_for_fusion(sch, subgraphs, new_mod):
+    """Transfer hooks of modules in the subgraph to be fused.
+    For example, the fwd_pre hook of the first module in the subgraph will become
+    the fwd_pre hook of the fused module.
+    Note that if middle modules have hooks, we will throw errors because we cannot
+    keep these hooks in the fused module.
+
+    Parameters
+    ----------
+    sch : Schedule
+        The parent schedule.
+    subgraphs : List[List[Tuple[Node, Node]]]
+        The fused subgraphs that need to be transferred hooks.
+    new_mod : torch.nn.Module
+        The new module that will be created after fusion.
+    """
+    hook_types = HOOK_TYPE_TO_ATTR.keys()
+    if len(subgraphs) > 1:
+        # Since horizontal fusion needs to combine the hooks together,
+        # we cannot support it for now.
+        for i, sublst in enumerate(subgraphs):
+            for _, node in sublst:
+                if node.op != "call_module":
+                    break
+                old_mod = sch.get_module(node.target)
+                for hook in hook_types:
+                    if has_hook(old_mod, hook) > 0:
+                        raise RuntimeError(
+                            f"Cannot use horizontal fusion since module {node.target} has a {hook} hook"
+                        )
+    else:
+        ops = subgraphs[0]
+        for i, (_, node) in enumerate(ops):
+            if node.op == "call_module":
+                old_mod = sch.get_module(node.target)
+                if i == 0:  # the first node
+                    if has_hook(old_mod, "fwd_post"):
+                        raise RuntimeError(
+                            f"Cannot transfer hooks from {node.target} to the new module since {node.target} has a fwd_post hook"
+                        )
+                    transfer_hooks(old_mod, new_mod, ["fwd_pre", "bwd_post"])
+                elif i == len(ops) - 1:  # the last node
+                    if has_hook(old_mod, "fwd_pre") or has_hook(old_mod, "bwd_post"):
+                        raise RuntimeError(
+                            f"Cannot transfer hooks from {node.target} to the new module since {node.target} has a fwd_pre/bwd_post hook"
+                        )
+                    transfer_hooks(old_mod, new_mod, ["fwd_post"])
+                elif any(has_hook(old_mod, x) for x in hook_types):
+                    raise RuntimeError(
+                        f"Cannot transfer hooks from {node.target} to the new module since {node.target} is in the middle of the subgraph"
+                    )
+
+
 def is_lambda_function(obj):
     return isinstance(obj, FunctionType) and obj.__name__ == "<lambda>"
 
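To restate the vertical-fusion rules encoded above: the first module in the chain may contribute its fwd_pre and bwd_post hooks, the last module its fwd_post hook, and a hook anywhere else raises because there is no spot for it on the fused module. A plain-PyTorch sketch of the legal case (not slapo code; the Sequential stands in for the fused module, and because it reuses the original submodules their own hook copies also fire here, which would not happen after a real fusion):

    import torch

    def log_pre(mod, inputs):
        print("fwd_pre fires on", type(mod).__name__)

    def log_post(mod, inputs, output):
        print("fwd_post fires on", type(mod).__name__)

    first, last = torch.nn.Linear(4, 4), torch.nn.ReLU()
    first.register_forward_pre_hook(log_pre)  # allowed: first node, fwd_pre
    last.register_forward_hook(log_post)      # allowed: last node, fwd_post

    fused = torch.nn.Sequential(first, last)
    # The two transfer_hooks calls above amount to copying these dicts:
    fused._forward_pre_hooks = first._forward_pre_hooks
    fused._forward_hooks = last._forward_hooks
    fused(torch.randn(2, 4))  # pre fires entering fused, post fires leaving it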