[T5X] Improve support for GPU, including multi-host GPU.
PiperOrigin-RevId: 458015130
jekbradbury authored and t5-copybara committed Jun 29, 2022
1 parent 045a92a commit 6eaa981
Showing 2 changed files with 55 additions and 10 deletions.
34 changes: 24 additions & 10 deletions t5x/partitioning.py
@@ -29,6 +29,7 @@
from jax import random
from jax.experimental import PartitionSpec
from jax.experimental.maps import Mesh
from jax.experimental.mesh_utils import create_hybrid_device_mesh
from jax.experimental.pjit import pjit as jax_pjit
import numpy as np
from t5x import train_state as train_state_lib
@@ -269,7 +270,7 @@ def dh_dd_mh_md(g: int, m: int, l: int) -> Tuple[int, int, int, int]:
# reshape to (data, model)
devices = devices.reshape(-1, np.prod(model_parallel_submesh))
global_mesh = Mesh(devices, ['data', 'model'])
logging.info('global_mesh axes_names: %s', global_mesh.axis_names)
logging.info('global_mesh axis_names: %s', global_mesh.axis_names)
logging.info('global_mesh devices: %s', global_mesh.devices)
return global_mesh

@@ -283,13 +284,26 @@ def get_cpu_mesh() -> Mesh:
return Mesh(devices, ['data', 'model'])


def get_gpu_mesh() -> Mesh:
"""Simple mesh for GPUs."""
devices = np.empty((jax.host_count(), jax.local_device_count()),
dtype=np.object)
for device in jax.devices():
devices[device.process_index, device.id % jax.local_device_count()] = device
return Mesh(devices, ['data', 'model'])
def get_gpu_mesh(num_partitions: int) -> Mesh:
"""Mesh for GPUs that preferentially places 'model' on NVLink."""
nvlink_size = jax.local_device_count()
dcn_size = jax.process_count()
nvlink_mp = min(num_partitions, nvlink_size)
nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp)
dcn_mp, extra2 = divmod(num_partitions, nvlink_mp)
assert not (extra1 or extra2), ('number of partitions on GPU must be a factor'
' or multiple of the number of local devices')
dcn_dp = dcn_size // dcn_mp

devices = create_hybrid_device_mesh(
mesh_shape=[nvlink_dp, nvlink_mp],
dcn_mesh_shape=[dcn_dp, dcn_mp],
process_is_granule=True)

global_mesh = Mesh(devices, ['data', 'model'])
logging.info('global_mesh axis_names: %s', global_mesh.axis_names)
logging.info('global_mesh devices: %s', global_mesh.devices)
return global_mesh


def default_mesh(num_partitions: int,
@@ -320,7 +334,7 @@ def default_mesh(num_partitions:
if platform == 'cpu':
return get_cpu_mesh()
elif platform == 'gpu':
return get_gpu_mesh()
return get_gpu_mesh(num_partitions)

mps = None
if device_kind in ('TPU v2', 'TPU v3'):
@@ -703,7 +717,7 @@ def partition(
`PartitionSpec`: a tuple of length at most equal to the rank of the
partitioned value. Each element can be a `None`, a mesh axis or a
tuple of mesh axes, and specifies the set of resources assigned to
partition the value's dimension matching its position in the spec.
partition the value's dimension matching its position in the spec.
out_axis_resources: Like `in_axis_resources`, but specifies resource
assignment for function outputs.
static_argnums: an optional int or collection of ints that specify which
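For context, here is a minimal sketch of the mesh-shape arithmetic performed by the new get_gpu_mesh, assuming a hypothetical cluster of 2 processes with 8 local GPUs each and num_partitions=4 (these values are chosen for illustration, not taken from the commit):

# Illustrative values; the real function reads them from jax at runtime.
nvlink_size = 8   # jax.local_device_count(): GPUs sharing NVLink on one host (assumed)
dcn_size = 2      # jax.process_count(): number of hosts (assumed)
num_partitions = 4

nvlink_mp = min(num_partitions, nvlink_size)        # 4 model-parallel devices per host
nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp)  # 2 data-parallel groups per host
dcn_mp, extra2 = divmod(num_partitions, nvlink_mp)  # 1: no model parallelism across hosts
assert not (extra1 or extra2)
dcn_dp = dcn_size // dcn_mp                         # 2 data-parallel groups across hosts

# create_hybrid_device_mesh([nvlink_dp, nvlink_mp], [dcn_dp, dcn_mp], ...) then yields a
# ('data', 'model') mesh of shape (nvlink_dp * dcn_dp, nvlink_mp * dcn_mp) = (4, 4),
# keeping model-parallel traffic on NVLink and data-parallel traffic on the DCN.
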
31 changes: 31 additions & 0 deletions t5x/train.py
@@ -660,6 +660,22 @@ def _run_inference_eval():
'seqio_additional_cache_dirs', [],
'Directories to search for cached Tasks in addition to defaults.')

flags.DEFINE_boolean(
'multiprocess_gpu',
False,
help='Initialize JAX distributed system for multi-host GPU, using '
'`coordinator_address`, `process_count`, and `process_index`.')

flags.DEFINE_string(
'coordinator_address',
None,
help='IP address:port for multi-host GPU coordinator.')

flags.DEFINE_integer(
'process_count', None, help='Number of processes for multi-host GPU.')

flags.DEFINE_integer('process_index', None, help='Index of this process.')



def main(argv: Sequence[str]):
@@ -671,6 +687,21 @@ def _main(argv: Sequence[str]):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')

if FLAGS.multiprocess_gpu:
if (FLAGS.coordinator_address is None or FLAGS.process_count is None or
FLAGS.process_index is None):
raise ValueError(
'`coordinator_address`, `process_count` and `process_index` '
'must be provided alongside `multiprocess_gpu`')

logging.info(
'Initializing distributed system for multi-host GPU:\n'
' coordinator_address: %s\n process_count: %s\n process_index: %s',
FLAGS.coordinator_address, FLAGS.process_count, FLAGS.process_index)

jax.distributed.initialize(FLAGS.coordinator_address, FLAGS.process_count,
FLAGS.process_index)

if FLAGS.tfds_data_dir:
seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)

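As a usage sketch of the new multi-host path (the coordinator address, port, and counts below are assumptions for illustration, not values from the commit), each process effectively performs the following before any other JAX work when --multiprocess_gpu is set:

import jax

coordinator_address = '10.0.0.1:1234'  # assumed host:port of process 0
process_count = 2                      # assumed number of hosts in the job
process_index = 0                      # 0 on the first host, 1 on the second

# Must run before other JAX calls so every host sees the global GPU topology.
jax.distributed.initialize(coordinator_address, process_count, process_index)

print(jax.process_count())       # 2 once initialization succeeds
print(jax.local_device_count())  # only the GPUs attached to this host
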
