diff --git a/benchmark/deepspeed_utils.py b/benchmark/deepspeed_utils.py index dd4a4e41..ce9b2baa 100644 --- a/benchmark/deepspeed_utils.py +++ b/benchmark/deepspeed_utils.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from slapo.model_dialect import get_dialect_cls +from slapo.framework_dialect import get_dialect_cls def add_deepspeed_parser(common_parser, subprasers): diff --git a/benchmark/megatron_utils.py b/benchmark/megatron_utils.py index 8eb722bc..ea802659 100644 --- a/benchmark/megatron_utils.py +++ b/benchmark/megatron_utils.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from slapo.model_dialect import get_dialect_cls +from slapo.framework_dialect import get_dialect_cls def add_megatron_parser(common_parser, subprasers): diff --git a/slapo/autotune/tune.py b/slapo/autotune/tune.py index 23b6105d..47970ac2 100644 --- a/slapo/autotune/tune.py +++ b/slapo/autotune/tune.py @@ -13,7 +13,7 @@ import time from slapo.logger import get_logger -from slapo.model_dialect import get_dialect_cls +from slapo.framework_dialect import get_dialect_cls logger = get_logger() diff --git a/slapo/model_dialect/__init__.py b/slapo/framework_dialect/__init__.py similarity index 100% rename from slapo/model_dialect/__init__.py rename to slapo/framework_dialect/__init__.py diff --git a/slapo/model_dialect/deepspeed/__init__.py b/slapo/framework_dialect/deepspeed/__init__.py similarity index 100% rename from slapo/model_dialect/deepspeed/__init__.py rename to slapo/framework_dialect/deepspeed/__init__.py diff --git a/slapo/model_dialect/deepspeed/engine.py b/slapo/framework_dialect/deepspeed/engine.py similarity index 91% rename from slapo/model_dialect/deepspeed/engine.py rename to slapo/framework_dialect/deepspeed/engine.py index c6971ac7..62d511b4 100644 --- a/slapo/model_dialect/deepspeed/engine.py +++ b/slapo/framework_dialect/deepspeed/engine.py @@ -2,13 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -from ..registry import register_model_dialect +from ..registry import register_framework_dialect from ...logger import get_logger, INFO logger = get_logger("DS-Engine", INFO) -@register_model_dialect("deepspeed", "runtime_engine") +@register_framework_dialect("deepspeed", "runtime_engine") def init_ds_engine(model, **kwargs): """Initialize the DeepSpeed engine.""" import deepspeed diff --git a/slapo/model_dialect/deepspeed/pipeline.py b/slapo/framework_dialect/deepspeed/pipeline.py similarity index 98% rename from slapo/model_dialect/deepspeed/pipeline.py rename to slapo/framework_dialect/deepspeed/pipeline.py index a7208c1d..80470340 100644 --- a/slapo/model_dialect/deepspeed/pipeline.py +++ b/slapo/framework_dialect/deepspeed/pipeline.py @@ -8,7 +8,7 @@ from torch import distributed as dist from torch import fx, nn -from ..registry import register_model_dialect +from ..registry import register_framework_dialect from ...logger import get_logger, DEBUG, INFO # Change INFO to DEBUG for more verbose logging. @@ -196,7 +196,7 @@ def analyze_tie_ranks(tie_weight_groups, topology): return tie_ranks, tie_stages -@register_model_dialect("deepspeed", "pipeline_stage") +@register_framework_dialect("deepspeed", "pipeline_stage") class DeepSpeedPipeStageWrapper(nn.Module): def __init__( self, @@ -323,7 +323,7 @@ def forward(self, *args, **kwargs): return ret -@register_model_dialect("deepspeed", "pipeline_engine") +@register_framework_dialect("deepspeed", "pipeline_engine") def deepspeed_pipe_engine( sch_metadata, stage_modules, diff --git a/slapo/model_dialect/deepspeed/utils.py b/slapo/framework_dialect/deepspeed/utils.py similarity index 93% rename from slapo/model_dialect/deepspeed/utils.py rename to slapo/framework_dialect/deepspeed/utils.py index 0163d3e5..6d28b5c6 100644 --- a/slapo/model_dialect/deepspeed/utils.py +++ b/slapo/framework_dialect/deepspeed/utils.py @@ -3,10 +3,10 @@ import re -from ..registry import register_model_dialect +from ..registry import register_framework_dialect -@register_model_dialect("deepspeed", "log_parser") +@register_framework_dialect("deepspeed", "log_parser") class DeepSpeedLogParser: @staticmethod def parse_log(log_filename): diff --git a/slapo/model_dialect/megatron/__init__.py b/slapo/framework_dialect/megatron/__init__.py similarity index 100% rename from slapo/model_dialect/megatron/__init__.py rename to slapo/framework_dialect/megatron/__init__.py diff --git a/slapo/model_dialect/megatron/utils.py b/slapo/framework_dialect/megatron/utils.py similarity index 96% rename from slapo/model_dialect/megatron/utils.py rename to slapo/framework_dialect/megatron/utils.py index a69ef3ee..f76f8840 100644 --- a/slapo/model_dialect/megatron/utils.py +++ b/slapo/framework_dialect/megatron/utils.py @@ -3,10 +3,10 @@ import re -from ..registry import register_model_dialect +from ..registry import register_framework_dialect -@register_model_dialect("megatron", "log_parser") +@register_framework_dialect("megatron", "log_parser") class MegatronLogParser: @staticmethod def parse_log(log_filename): diff --git a/slapo/model_dialect/registry.py b/slapo/framework_dialect/registry.py similarity index 77% rename from slapo/model_dialect/registry.py rename to slapo/framework_dialect/registry.py index 1bbdf9aa..0c7b1606 100644 --- a/slapo/model_dialect/registry.py +++ b/slapo/framework_dialect/registry.py @@ -1,6 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -"""Framework model dialect registration.""" +"""Framework dialect registration.""" DIALECTS = { "pipeline_stage": {}, @@ -10,15 +10,15 @@ } -def register_model_dialect(target, cls_type): - """Register a framework model dialect.""" +def register_framework_dialect(target, cls_type): + """Register a framework dialect.""" if cls_type not in DIALECTS: raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}") def decorator(dialect_cls): if "target" in DIALECTS[cls_type]: raise ValueError( - f"Target {target} already registered for {cls_type} model dialects" + f"Target {target} already registered for {cls_type} dialects" ) DIALECTS[cls_type][target] = dialect_cls return dialect_cls @@ -27,14 +27,14 @@ def decorator(dialect_cls): def get_all_dialects(cls_type): - """Get all registered framework model dialects.""" + """Get all registered framework dialects.""" if cls_type not in DIALECTS: raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}") return DIALECTS[cls_type] def get_dialect_cls(cls_type, target, allow_none=False): - """Get the framework model dialect class.""" + """Get the framework dialect class.""" if cls_type not in DIALECTS: raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}") if target not in DIALECTS[cls_type]: @@ -46,7 +46,5 @@ def get_dialect_cls(cls_type, target, allow_none=False): f"Target {target} does not register default dialect for {cls_type}" ) else: - raise ValueError( - f"Target {target} not registered for {cls_type} model dialects" - ) + raise ValueError(f"Target {target} not registered for {cls_type} dialects") return DIALECTS[cls_type][target] diff --git a/slapo/pipeline.py b/slapo/pipeline.py index 26fc3767..1a8aab6c 100644 --- a/slapo/pipeline.py +++ b/slapo/pipeline.py @@ -8,7 +8,7 @@ from torch.fx.passes.split_module import split_module from .logger import get_logger -from .model_dialect import get_dialect_cls +from .framework_dialect import get_dialect_cls from .utils.common import transfer_hooks logger = get_logger() diff --git a/slapo/schedule.py b/slapo/schedule.py index 6766adde..2ee90879 100644 --- a/slapo/schedule.py +++ b/slapo/schedule.py @@ -22,7 +22,7 @@ from .checkpoint import checkpoint as checkpoint_module from .logger import get_logger -from .model_dialect import get_dialect_cls +from .framework_dialect import get_dialect_cls from .pipeline import ( analyze_tie_weights, build_pipeline_model, diff --git a/tests/end2end.py b/tests/end2end.py index a8cb3b02..303c7edb 100644 --- a/tests/end2end.py +++ b/tests/end2end.py @@ -8,7 +8,7 @@ import os import pytest -from slapo.model_dialect import get_dialect_cls +from slapo.framework_dialect import get_dialect_cls def parse_log(impl, log_file): diff --git a/tests/test_pipeline_partition.py b/tests/test_pipeline_partition.py index 464af90c..e00517d3 100644 --- a/tests/test_pipeline_partition.py +++ b/tests/test_pipeline_partition.py @@ -107,7 +107,10 @@ def forward(self, x): model.stage2.linear.weight = model.stage1.linear.weight tie_weights = list(slapo.pipeline.analyze_tie_weights(model, True).values()) - tie_ranks, tie_stages = slapo.model_dialect.deepspeed.pipeline.analyze_tie_ranks( + ( + tie_ranks, + tie_stages, + ) = slapo.framework_dialect.deepspeed.pipeline.analyze_tie_ranks( tie_weights, topology )