Skip to content

Commit

Permalink
[Refactor] model_dialect -> framework_dialect (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored Feb 11, 2023
1 parent 6c2e235 commit d2bcf59
Show file tree
Hide file tree
Showing 15 changed files with 26 additions and 25 deletions.
2 changes: 1 addition & 1 deletion benchmark/deepspeed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import os
from slapo.model_dialect import get_dialect_cls
from slapo.framework_dialect import get_dialect_cls


def add_deepspeed_parser(common_parser, subprasers):
Expand Down
2 changes: 1 addition & 1 deletion benchmark/megatron_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import os
from slapo.model_dialect import get_dialect_cls
from slapo.framework_dialect import get_dialect_cls


def add_megatron_parser(common_parser, subprasers):
Expand Down
2 changes: 1 addition & 1 deletion slapo/autotune/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import time

from slapo.logger import get_logger
from slapo.model_dialect import get_dialect_cls
from slapo.framework_dialect import get_dialect_cls


logger = get_logger()
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from ..registry import register_model_dialect
from ..registry import register_framework_dialect
from ...logger import get_logger, INFO

logger = get_logger("DS-Engine", INFO)


@register_model_dialect("deepspeed", "runtime_engine")
@register_framework_dialect("deepspeed", "runtime_engine")
def init_ds_engine(model, **kwargs):
"""Initialize the DeepSpeed engine."""
import deepspeed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch import distributed as dist
from torch import fx, nn

from ..registry import register_model_dialect
from ..registry import register_framework_dialect
from ...logger import get_logger, DEBUG, INFO

# Change INFO to DEBUG for more verbose logging.
Expand Down Expand Up @@ -196,7 +196,7 @@ def analyze_tie_ranks(tie_weight_groups, topology):
return tie_ranks, tie_stages


@register_model_dialect("deepspeed", "pipeline_stage")
@register_framework_dialect("deepspeed", "pipeline_stage")
class DeepSpeedPipeStageWrapper(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -323,7 +323,7 @@ def forward(self, *args, **kwargs):
return ret


@register_model_dialect("deepspeed", "pipeline_engine")
@register_framework_dialect("deepspeed", "pipeline_engine")
def deepspeed_pipe_engine(
sch_metadata,
stage_modules,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import re

from ..registry import register_model_dialect
from ..registry import register_framework_dialect


@register_model_dialect("deepspeed", "log_parser")
@register_framework_dialect("deepspeed", "log_parser")
class DeepSpeedLogParser:
@staticmethod
def parse_log(log_filename):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import re

from ..registry import register_model_dialect
from ..registry import register_framework_dialect


@register_model_dialect("megatron", "log_parser")
@register_framework_dialect("megatron", "log_parser")
class MegatronLogParser:
@staticmethod
def parse_log(log_filename):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Framework model dialect registration."""
"""Framework dialect registration."""

DIALECTS = {
"pipeline_stage": {},
Expand All @@ -10,15 +10,15 @@
}


def register_model_dialect(target, cls_type):
"""Register a framework model dialect."""
def register_framework_dialect(target, cls_type):
"""Register a framework dialect."""
if cls_type not in DIALECTS:
raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}")

def decorator(dialect_cls):
if "target" in DIALECTS[cls_type]:
raise ValueError(
f"Target {target} already registered for {cls_type} model dialects"
f"Target {target} already registered for {cls_type} dialects"
)
DIALECTS[cls_type][target] = dialect_cls
return dialect_cls
Expand All @@ -27,14 +27,14 @@ def decorator(dialect_cls):


def get_all_dialects(cls_type):
"""Get all registered framework model dialects."""
"""Get all registered framework dialects."""
if cls_type not in DIALECTS:
raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}")
return DIALECTS[cls_type]


def get_dialect_cls(cls_type, target, allow_none=False):
"""Get the framework model dialect class."""
"""Get the framework dialect class."""
if cls_type not in DIALECTS:
raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}")
if target not in DIALECTS[cls_type]:
Expand All @@ -46,7 +46,5 @@ def get_dialect_cls(cls_type, target, allow_none=False):
f"Target {target} does not register default dialect for {cls_type}"
)
else:
raise ValueError(
f"Target {target} not registered for {cls_type} model dialects"
)
raise ValueError(f"Target {target} not registered for {cls_type} dialects")
return DIALECTS[cls_type][target]
2 changes: 1 addition & 1 deletion slapo/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.fx.passes.split_module import split_module

from .logger import get_logger
from .model_dialect import get_dialect_cls
from .framework_dialect import get_dialect_cls
from .utils.common import transfer_hooks

logger = get_logger()
Expand Down
2 changes: 1 addition & 1 deletion slapo/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from .checkpoint import checkpoint as checkpoint_module
from .logger import get_logger
from .model_dialect import get_dialect_cls
from .framework_dialect import get_dialect_cls
from .pipeline import (
analyze_tie_weights,
build_pipeline_model,
Expand Down
2 changes: 1 addition & 1 deletion tests/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import pytest

from slapo.model_dialect import get_dialect_cls
from slapo.framework_dialect import get_dialect_cls


def parse_log(impl, log_file):
Expand Down
5 changes: 4 additions & 1 deletion tests/test_pipeline_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@ def forward(self, x):
model.stage2.linear.weight = model.stage1.linear.weight

tie_weights = list(slapo.pipeline.analyze_tie_weights(model, True).values())
tie_ranks, tie_stages = slapo.model_dialect.deepspeed.pipeline.analyze_tie_ranks(
(
tie_ranks,
tie_stages,
) = slapo.framework_dialect.deepspeed.pipeline.analyze_tie_ranks(
tie_weights, topology
)

Expand Down

0 comments on commit d2bcf59

Please sign in to comment.