Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add init_system API for multi-host GPU. #8364

Merged
merged 1 commit into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.24...main).

* New features:
* (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend.
* Breaking changes
* Moved `jax.experimental.stax` to `jax.example_libraries.stax`
* Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`
Expand Down
1 change: 1 addition & 0 deletions jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
# jax and rely on the names imported above.
from . import abstract_arrays as abstract_arrays
from . import api_util as api_util
from . import distributed as distributed
from . import dtypes as dtypes
from . import errors as errors
from . import image as image
Expand Down
59 changes: 59 additions & 0 deletions jax/_src/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools

from absl import logging
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension

_service = None
def initialize(coordinator_address: str, num_processes: int, process_id: int):
"""Initialize distributed system for topology discovery.

Currently, calling ``initialize`` sets up the multi-host GPU backend, and
is not required for CPU or TPU backends.

Args:
coordinator_address: IP address of the coordinator.
num_processes: Number of processes.
process_id: Id of the current processe.

Example:

Suppose there are two GPU hosts, and host 0 is the designated coordinator
with address '10.0.0.1:1234', to initialize the GPU cluster, run the
following commands before anything else.

On host 0
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 0) # doctest: +SKIP

On host 1
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 1) # doctest: +SKIP
"""
if process_id == 0:
global _service
assert _service is None, 'initialize should be called once only'
logging.info('Starting JAX distributed service on %s', coordinator_address)
_service = xla_extension.get_distributed_runtime_service(coordinator_address,
num_processes)

client = xla_extension.get_distributed_runtime_client(coordinator_address,
process_id)
logging.info('Connecting to JAX distributed service on %s', coordinator_address)
client.connect()

factory = functools.partial(xla_client.make_gpu_client, client, process_id)
xla_bridge.register_backend_factory('gpu', factory, priority=300)
12 changes: 7 additions & 5 deletions jax/_src/lib/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,15 @@ def _log_warning():
# example, there could be multiple backends that provide the same kind of
# device.
_backend_factories = {}
_default_backend = None
_backends : Dict[str, Any] = {}
_backends_errors : Dict[str, str] = {}
_backend_lock = threading.Lock()

def register_backend_factory(name, factory, *, priority=0):
with _backend_lock:
if name in _backends:
raise RuntimeError(f"Backend {name} already initialized")
_backend_factories[name] = (factory, priority)


Expand All @@ -187,11 +194,6 @@ def register_backend_factory(name, factory, *, priority=0):
register_backend_factory(
'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)

_default_backend = None
_backends : Dict[str, Any] = {}
_backends_errors : Dict[str, str] = {}
_backend_lock = threading.Lock()


def backends():
global _backends
Expand Down
16 changes: 16 additions & 0 deletions jax/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.distributed import initialize