Add Configuration setup for SageMaker #17
from .config_utils import _ask_field, _convert_distributed_mode, _convert_yes_no_to_bool


def get_cluster_input():
Decoupled the input gathering from the main config `get_user_input()` into separate functions, one per option.


def get_user_input():
compute_environment = _ask_field(
Determines which config flow should be used.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Created a separate file to hold all configuration dataclasses. We could also think about moving them to `state.py`.
default_config_file = default_json_config_file


def load_config_from_file(config_file):
General loading function that determines whether the config file is `json` or `yaml` and which configuration class is used for loading. Open to suggestions on how to identify the `config_class` more easily.
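One way to identify the `config_class`, sketched here as an assumption and not as the PR's actual implementation, is to key a small registry on the `compute_environment` value stored in the file. The registry `CONFIG_CLASSES`, the helper `parse_config_text`, the `AMAZON_SAGEMAKER` key and the `region` field are all hypothetical names used for illustration; the YAML branch assumes PyYAML is installed.

```python
import json
from dataclasses import dataclass
from pathlib import Path


@dataclass
class ClusterConfig:
    compute_environment: str
    num_processes: int


@dataclass
class SageMakerConfig:
    compute_environment: str
    region: str  # hypothetical field, for illustration only


# Hypothetical registry: maps the stored compute environment to a config class.
CONFIG_CLASSES = {
    "CUSTOM_CLUSTER": ClusterConfig,
    "AMAZON_SAGEMAKER": SageMakerConfig,  # assumed name, not taken from the PR
}


def parse_config_text(text, suffix):
    """Parse raw file content according to the file extension."""
    if suffix == ".json":
        return json.loads(text)
    if suffix in (".yaml", ".yml"):
        import yaml  # PyYAML, imported lazily so JSON-only users do not need it

        return yaml.safe_load(text)
    raise ValueError(f"Unsupported config file extension: {suffix!r}")


def load_config_from_file(config_file):
    """Load a config file and instantiate the matching config dataclass."""
    path = Path(config_file)
    data = parse_config_text(path.read_text(encoding="utf-8"), path.suffix.lower())
    config_class = CONFIG_CLASSES[data["compute_environment"]]
    return config_class(**data)
```

With a registry like this, adding a new environment only requires a new dataclass and one dictionary entry; the loader itself stays unchanged.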


@dataclass
class BaseConfig:
The base class contains the properties and methods shared by the configurations.
class ClusterConfig(BaseConfig):
num_processes: int
machine_rank: int = 0
num_machines: int = 1
Kept `num_machines` in the specialized class because of a dataclass restriction: fields without a default value cannot appear after fields with default values.
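The restriction can be reproduced in isolation. The sketch below is illustrative only: the `BaseConfig` fields are assumed from the attributes that appear elsewhere in this PR, and `define_broken_hierarchy` is a hypothetical helper that shows what happens if a defaulted field moves into the base class.

```python
from dataclasses import dataclass


@dataclass
class BaseConfig:
    # Only fields without defaults live in the base class.
    compute_environment: str
    distributed_type: str


@dataclass
class ClusterConfig(BaseConfig):
    # A required field comes first, then the defaulted ones.
    num_processes: int
    machine_rank: int = 0
    num_machines: int = 1


def define_broken_hierarchy():
    """Defining these classes raises TypeError at class-creation time."""

    @dataclass
    class Base:
        num_machines: int = 1  # defaulted field in the base class

    @dataclass
    class Child(Base):
        num_processes: int  # required field after an inherited default

    return Child
```

Inherited fields come first in the generated `__init__`, so a default in the base class forces every subclass field to have a default as well. On Python 3.10 and later, `@dataclass(kw_only=True)` is one way around this, at the cost of keyword-only construction.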
src/accelerate/state.py
Outdated
"""

# Subclassing str as well as Enum allows the `ComputeEnvironment` to be JSON-serializable out of the box.
CUSTOM_CLUSTER = "CUSTOM_CLUSTER"
Open to better naming.
I think `LOCAL_MACHINE` is better.
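The code comment in the diff about subclassing `str` can be checked with a short, self-contained sketch. The member `LOCAL_MACHINE` follows the naming proposed in this thread; `AMAZON_SAGEMAKER` and the comparison class `PlainEnvironment` are assumptions added for illustration.

```python
import json
from enum import Enum


class ComputeEnvironment(str, Enum):
    # Members are real str instances, so the json module can encode them.
    LOCAL_MACHINE = "LOCAL_MACHINE"
    AMAZON_SAGEMAKER = "AMAZON_SAGEMAKER"  # assumed member name


class PlainEnvironment(Enum):
    # Same value, but without the str mixin (hypothetical comparison class).
    LOCAL_MACHINE = "LOCAL_MACHINE"


def dump(environment):
    """Serialize a config-like dict that holds an enum member."""
    return json.dumps({"compute_environment": environment})


serialized = dump(ComputeEnvironment.LOCAL_MACHINE)
restored = ComputeEnvironment(json.loads(serialized)["compute_environment"])
```

Calling `dump(PlainEnvironment.LOCAL_MACHINE)` raises `TypeError`, because a plain `Enum` member is not one of the types the `json` encoder knows how to serialize.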
Amazing work! To complete the decoupling of `commands.config` while leaving stuff easily accessible, I would put all the main functions/classes of `commands.config.xxx.py` inside the `commands.config.__init__.py`.
Then I think LOCAL_MACHINE is a better name than CUSTOM_CLUSTER which sounds a bit too grand for most users ;-)
Thanks for adding a basic CI, it was on my TODO but I was too lazy to actually do it ;-)
@@ -16,7 +16,7 @@

from argparse import ArgumentParser

-from accelerate.commands.config import config_command_parser
+from accelerate.commands.config.config import config_command_parser
This is a bit weird, so I would remove one `.config` by adding `config_command_parser` to the intermediate `__init__.py`.

def get_user_input():
compute_environment = _ask_field(
"In which compute environment are you running? ([0] Custom Cluster, [1] AWS (Amazon SageMaker)): ",
Custom Cluster does sound a bit too complicated for the base config.
-"In which compute environment are you running? ([0] Custom Cluster, [1] AWS (Amazon SageMaker)): ",
+"In which compute environment are you running? ([0] This machine, [1] AWS (Amazon SageMaker)): ",
yaml.safe_dump(self.to_dict(), f)

def __post_init__(self):
if isinstance(self.distributed_type, str):
-if isinstance(self.distributed_type, str):
+if isinstance(self.compute_environment, str):
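The suggestion matters because each `isinstance` check should guard the field that is actually converted. A minimal sketch of such a `__post_init__`, assuming the hook turns string values loaded from JSON or YAML back into enum members; the enum members shown are illustrative and may not match the PR's full list.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Union


class ComputeEnvironment(str, Enum):
    LOCAL_MACHINE = "LOCAL_MACHINE"
    AMAZON_SAGEMAKER = "AMAZON_SAGEMAKER"  # assumed member name


class DistributedType(str, Enum):
    NO = "NO"
    MULTI_GPU = "MULTI_GPU"


@dataclass
class BaseConfig:
    compute_environment: Union[ComputeEnvironment, str]
    distributed_type: Union[DistributedType, str]

    def __post_init__(self):
        # Each check guards the same field it converts.
        if isinstance(self.compute_environment, str):
            self.compute_environment = ComputeEnvironment(self.compute_environment)
        if isinstance(self.distributed_type, str):
            self.distributed_type = DistributedType(self.distributed_type)
```

Because both enums subclass `str`, the checks are also true for values that are already enum members; calling the enum on a member simply returns that member, so the conversion is idempotent.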
src/accelerate/commands/launch.py
Outdated
@@ -23,8 +23,8 @@
from pathlib import Path
from typing import Optional

from accelerate.commands.config import LaunchConfig, default_config_file
from accelerate.state import DistributedType
from accelerate.commands.config.config_args import default_config_file, load_config_from_file
I would also put those two in the intermediate init, `accelerate.commands.config.__init__.py`.
src/accelerate/utils.py
Outdated
_has_boto3 = importlib.util.find_spec("boto3") is not None

Would define a public `is_boto3_available()` function instead of a private variable.
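A minimal sketch of what such a public helper could look like, reusing the same `importlib.util.find_spec` check from the diff. The function name comes from the review comment; the body is an assumption about how it might be written.

```python
import importlib.util


def is_boto3_available() -> bool:
    """Return True if the optional `boto3` dependency can be imported."""
    # find_spec locates the package without importing it.
    return importlib.util.find_spec("boto3") is not None
```

Callers can then guard SageMaker-only code paths with `if is_boto3_available():` instead of importing a private module-level flag.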
extras["test"] = [
"pytest",
"pytest-xdist",
]
setup(
I think we can add an extra "sagemaker" with boto3 inside.
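The suggested extra could be declared next to the existing `test` extra. This is a sketch of the relevant `setup.py` fragment only; the `extras = {}` initialisation is assumed, and passing the dictionary through `extras_require` is the usual setuptools mechanism, not something shown in this diff.

```python
# Fragment of a setup.py: optional dependency groups.
extras = {}
extras["test"] = [
    "pytest",
    "pytest-xdist",
]
# Proposed extra: pulls in boto3 only for users who target SageMaker.
extras["sagemaker"] = [
    "boto3",
]

# In the real setup.py this dictionary would be handed to setuptools:
# setup(..., extras_require=extras)
```

Users who need the SageMaker flow would then install the package with the `sagemaker` extra, for example `pip install "accelerate[sagemaker]"`.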
I can't reproduce why
Yes, LGTM! Thanks for working on this.
The quality issue may come from a version mismatch between your local setup and the CI?
This PR adds configuration possibilities for launch scripts on Amazon SageMaker. It contains only the configuration part of the CLI and not the job launch itself. I decoupled the `config` subparsers into multiple small files to reduce complexity and to structure the different config options clearly. This makes it easier to add further configurations in the future.
I also added 2 GitHub Actions which run `make quality` and `make test`.
I know it is a big PR. I hope we can iterate fast on it.