Skip to content

Commit 81aeea3

Browse files
jeffrasamyam
and authored
Elastic training support (#602)
Co-authored-by: Samyam Rajbhandari <[email protected]>
1 parent 7435b2f commit 81aeea3

File tree

16 files changed

+883
-22
lines changed

16 files changed

+883
-22
lines changed

.github/workflows/main.yml

+4-6
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@ name: Build
44

55
# Controls when the action will run.
66
on:
7-
# Triggers the workflow on push or pull request events but only for the master branch
87
push:
9-
branches: [ master ]
8+
paths-ignore:
9+
- 'docs/**'
1010
pull_request:
11-
branches: [ master ]
12-
13-
# Allows you to run this workflow manually from the Actions tab
14-
workflow_dispatch:
11+
paths-ignore:
12+
- 'docs/**'
1513

1614
# A workflow run is made up of one or more jobs that can run sequentially or in parallel
1715
jobs:

bin/ds_elastic

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#!/usr/bin/env python
"""CLI helper that prints a DeepSpeed elasticity config and, optionally, the
batch-size / valid-GPU results computed for a given world size.

Usage: ds_elastic -c <ds_config.json> [-w <world_size>]
"""

import argparse
import json

import deepspeed
from deepspeed.elasticity import compute_elastic_config


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, help="DeepSpeed config json")
    parser.add_argument('-w', '--world-size', type=int, default=0, help="Intended/current world size")
    args = parser.parse_args()

    # Use a context manager so the config file handle is closed promptly
    # (the original `json.load(open(...))` leaked the handle).
    with open(args.config, 'r') as fd:
        ds_config = json.load(fd)

    ds_version = deepspeed.__version__

    # Raises KeyError if the config has no "elasticity" section.
    elastic_config = ds_config['elasticity']
    print('------------------------------------------')
    print("Elasticity config:")
    print('------------------------------------------')
    print(json.dumps(elastic_config, indent=4, sort_keys=True))

    if args.world_size > 0:
        # With a concrete world size the helper also returns the micro-batch size.
        final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version, world_size=args.world_size)
        print('------------------------------------------')
        print(f"Calculated results for world size {args.world_size}:")
        print('------------------------------------------')
        print(f'final_batch_size .... {final_batch_size}')
        print(f'valid_gpus .......... {valid_gpus}')
        print(f'micro_batch_size .... {micro_batch_size}')
    else:
        final_batch_size, valid_gpus = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version)
        print('------------------------------------------')
        print("Calculated results:")
        print('------------------------------------------')
        print(f'final_batch_size .... {final_batch_size}')
        print(f'valid_gpus .......... {valid_gpus}')

deepspeed/elasticity/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config

deepspeed/elasticity/config.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""
2+
Copyright 2020 The Microsoft DeepSpeed Team
3+
"""
4+
5+
import json
6+
from .constants import *
7+
8+
9+
class ElasticityError(Exception):
    """Root of the elasticity exception hierarchy; all elasticity-related
    errors derive from this type."""
14+
15+
16+
class ElasticityConfigError(ElasticityError):
    """Raised when the user-supplied elasticity configuration is invalid or
    missing required fields."""
21+
22+
23+
class ElasticityIncompatibleWorldSize(ElasticityError):
    """Raised when a run is attempted at a world size that the given elastic
    config cannot support."""
28+
29+
30+
class ElasticityConfig:
    """
    Elastic config object, constructed from a param dictionary that only contains elastic
    config parameters, example below:

    If elasticity is enabled, user must specify (at least) max_train_batch_size
    and micro_batch_sizes.

    {
        "enabled": true,
        "max_train_batch_size": 2000,
        "micro_batch_sizes": [2,4,6],
        "min_gpus": 1,
        "max_gpus": 10000,
        "min_time": 20,
        "ignore_non_elastic_batch_info": false,
        "version": 0.1
    }

    Args:
        param_dict: dict holding only the elasticity section of a DeepSpeed config.

    Raises:
        ElasticityConfigError: if elasticity is enabled but max_train_batch_size
            or micro_batch_sizes is missing from param_dict.
    """
    def __init__(self, param_dict):
        self.enabled = param_dict.get(ENABLED, ENABLED_DEFAULT)
        if self.enabled:
            # When enabled, the max batch size and micro-batch list are mandatory —
            # there are no sensible defaults for an active elastic job.
            if MAX_ACCEPTABLE_BATCH_SIZE in param_dict:
                self.max_acceptable_batch_size = param_dict[MAX_ACCEPTABLE_BATCH_SIZE]
            else:
                raise ElasticityConfigError(
                    f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}")
            if MICRO_BATCHES in param_dict:
                self.micro_batches = param_dict[MICRO_BATCHES]
            else:
                raise ElasticityConfigError(f"Elasticity config missing {MICRO_BATCHES}")
        else:
            # Disabled: fall back to defaults so downstream code can still read the fields.
            self.max_acceptable_batch_size = param_dict.get(
                MAX_ACCEPTABLE_BATCH_SIZE,
                MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT)
            self.micro_batches = param_dict.get(MICRO_BATCHES, MICRO_BATCHES_DEFAULT)
        self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT)
        self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT)
        self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT)
        self.version = param_dict.get(VERSION, VERSION_DEFAULT)
        self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH,
                                                       PREFER_LARGER_BATCH_DEFAULT)
        self.ignore_non_elastic_batch_info = param_dict.get(
            IGNORE_NON_ELASTIC_BATCH_INFO,
            IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)

    def repr(self):
        # Non-dunder helper kept for backward compatibility: returns the raw
        # attribute dict rather than a string.
        return self.__dict__

    def __repr__(self):
        # Pretty-printed JSON view of all config attributes.
        return json.dumps(self.__dict__, sort_keys=True, indent=4)

deepspeed/elasticity/constants.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""
2+
Copyright 2020 The Microsoft DeepSpeed Team
3+
"""
4+
5+
#########################################
# Elasticity
#########################################
''' Elasticity Utility in DeepSpeed can be used to create highly elastic jobs compatible
with a large number of GPUs. For elastic jobs, DeepSpeed will provide a batch size that
can support a large number of GPUs based on the user specified parameters
'''
# User-facing help text shown when the elasticity section is malformed.
# NOTE: the example below must stay valid JSON — the original was missing a
# comma after "max_gpus", which made the example un-copy-pasteable.
FORMAT = '''
Elasticity should be enabled as:
"elasticity": {
  "enabled": true,
  "max_train_batch_size": 2000,
  "micro_batch_sizes": [2,4,6],
  "min_gpus": 1,
  "max_gpus": 10000,
  "min_time": 20,
  "prefer_larger_batch": true,
  "ignore_non_elastic_batch_info": false,
  "version": 0.1
}
'''

# Top-level key of the elasticity section inside a DeepSpeed config dict.
ELASTICITY = 'elasticity'

# Current elasticity version
LATEST_ELASTICITY_VERSION = 0.1

ENABLED = 'enabled'
ENABLED_DEFAULT = False

# Max acceptable train_batch_size
MAX_ACCEPTABLE_BATCH_SIZE = 'max_train_batch_size'
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT = 2000

# Acceptable micro batch sizes, same as train_micro_batch_size_per_gpu
MICRO_BATCHES = 'micro_batch_sizes'
MICRO_BATCHES_DEFAULT = [2, 4, 6]

# Min/max of GPUs to search over
MIN_GPUS = 'min_gpus'
MIN_GPUS_DEFAULT = 1
MAX_GPUS = 'max_gpus'
MAX_GPUS_DEFAULT = 10000

# Minimum running time (minutes) before the scheduler will scale us
MIN_TIME = "min_time"
# NOTE(review): default is the *string* "20" even though the example config and
# the comment above use a number of minutes — confirm against the scheduler
# consumer before changing the type.
MIN_TIME_DEFAULT = "20"

# When finding a suitable batch size, attempt to find one that is closest
# to the max train batch size given.
PREFER_LARGER_BATCH = 'prefer_larger_batch'
PREFER_LARGER_BATCH_DEFAULT = True

# In order to reduce confusion, if elastic mode is enabled we
# require (via assert) that no batch info is set outside of the
# elastic config. You can turn off this assert via this config
# but keep in mind that all batch info defined outside the
# elastic mode *will be ignored*.
IGNORE_NON_ELASTIC_BATCH_INFO = 'ignore_non_elastic_batch_info'
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT = False

# Version of elastic logic to use
VERSION = "version"
VERSION_DEFAULT = LATEST_ELASTICITY_VERSION

# Minimum deepspeed version to use elasticity
MINIMUM_DEEPSPEED_VERSION = "0.3.8"

# Environment variable storing elastic config from resource scheduler
DEEPSPEED_ELASTICITY_CONFIG = "DEEPSPEED_ELASTICITY_CONFIG"

0 commit comments

Comments
 (0)