From af0bb86e072d2c8a814ffb285aeb3213dfe10112 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 22 Jan 2025 09:47:37 -0800 Subject: [PATCH] Merge mscclpp-lang to mscclpp project (#442) First step to merge msccl-tools into mscclpp repo. In this step will move all msccl related code, pass the current tests and do some necessary refactor. Add `mscclpp.language` module Add `_InstructionOptimizer` and `DagOptimizer` class to optimize the dag Add `DagLower` to lower dag to intermediate representation Add documents for mscclpp.language Remove msccl related code --- .azure-pipelines/nccl-api-test.yaml | 17 - .github/workflows/mscclpp-lang.yml | 46 ++ README.md | 3 +- docs/design/mscclpp-dsl.md | 114 ++++ docs/index.rst | 2 + python/examples/allgather_barrier.py | 55 ++ python/examples/allreduce_allpairs.py | 65 +++ python/examples/allreduce_allpairs_get.py | 78 +++ python/examples/allreduce_allpairs_packet.py | 69 +++ python/examples/allreduce_nvls.py | 55 ++ python/examples/allreduce_ring.py | 59 ++ python/examples/send_recv_packet.py | 57 ++ python/examples/send_recv_proxy.py | 56 ++ python/mscclpp/language/__init__.py | 4 + python/mscclpp/language/buffer.py | 58 ++ python/mscclpp/language/chunk.py | 64 +++ python/mscclpp/language/collectives.py | 339 +++++++++++ python/mscclpp/language/dag/__init__.py | 6 + .../mscclpp/language/dag/instruction_dag.py | 373 ++++++++++++ python/mscclpp/language/dag/lower.py | 162 ++++++ python/mscclpp/language/dag/optimizer.py | 405 +++++++++++++ python/mscclpp/language/ir.py | 534 ++++++++++++++++++ python/mscclpp/language/program.py | 433 ++++++++++++++ python/mscclpp/language/rank.py | 33 ++ python/mscclpp/language/types.py | 173 ++++++ python/mscclpp/language/utils.py | 94 +++ .../configs/mscclpp_lang_test_config.json | 34 ++ .../test/test_generate_mscclpp_lang_result.py | 47 ++ 28 files changed, 3417 insertions(+), 18 deletions(-) create mode 100644 .github/workflows/mscclpp-lang.yml create mode 100644 docs/design/mscclpp-dsl.md create mode 100644 python/examples/allgather_barrier.py create mode 100644 python/examples/allreduce_allpairs.py create mode 100644 python/examples/allreduce_allpairs_get.py create mode 100644 python/examples/allreduce_allpairs_packet.py create mode 100644 python/examples/allreduce_nvls.py create mode 100644 python/examples/allreduce_ring.py create mode 100644 python/examples/send_recv_packet.py create mode 100644 python/examples/send_recv_proxy.py create mode 100644 python/mscclpp/language/__init__.py create mode 100644 python/mscclpp/language/buffer.py create mode 100644 python/mscclpp/language/chunk.py create mode 100644 python/mscclpp/language/collectives.py create mode 100644 python/mscclpp/language/dag/__init__.py create mode 100644 python/mscclpp/language/dag/instruction_dag.py create mode 100644 python/mscclpp/language/dag/lower.py create mode 100644 python/mscclpp/language/dag/optimizer.py create mode 100644 python/mscclpp/language/ir.py create mode 100644 python/mscclpp/language/program.py create mode 100644 python/mscclpp/language/rank.py create mode 100644 python/mscclpp/language/types.py create mode 100644 python/mscclpp/language/utils.py create mode 100644 python/test/configs/mscclpp_lang_test_config.json create mode 100644 python/test/test_generate_mscclpp_lang_result.py diff --git a/.azure-pipelines/nccl-api-test.yaml b/.azure-pipelines/nccl-api-test.yaml index 92b2d0d9a..33ae4c09d 100644 --- a/.azure-pipelines/nccl-api-test.yaml +++ b/.azure-pipelines/nccl-api-test.yaml @@ -87,23 +87,6 @@ jobs: parallel-scp -t 0 -r -h ${HOSTFILE} -x "-i ${KeyFilePath}" -O $SSH_OPTION ${ROOT_DIR} ${DST_DIR} workingDirectory: '$(System.DefaultWorkingDirectory)' - - task: Bash@3 - name: InstallMscclTools - displayName: Install msccl-tools - inputs: - targetType: 'inline' - script: | - set -e - HOSTFILE=$(System.DefaultWorkingDirectory)/mscclpp/test/deploy/hostfile_ci - SSH_OPTION="StrictHostKeyChecking=no" - KeyFilePath=${SSHKEYFILE_SECUREFILEPATH} - parallel-ssh -i -t 0 -h ${HOSTFILE} -x "-i ${KeyFilePath}" \ - -O $SSH_OPTION 'sudo docker exec -t mscclpp-test bash -c " \ - cd /root/mscclpp; \ - git clone https://github.com/Azure/msccl-tools.git; \ - cd /root/mscclpp/msccl-tools; pip3 install ."' - workingDirectory: '$(System.DefaultWorkingDirectory)' - - task: Bash@3 name: GenerateExecutionFile displayName: Generate execution file diff --git a/.github/workflows/mscclpp-lang.yml b/.github/workflows/mscclpp-lang.yml new file mode 100644 index 000000000..a83b8475f --- /dev/null +++ b/.github/workflows/mscclpp-lang.yml @@ -0,0 +1,46 @@ +name: MSCCLPPLang + +on: + pull_request: + branches: + - main + - release/* + +jobs: + compare-diffs: + runs-on: 'ubuntu-latest' + container: + image: ghcr.io/microsoft/mscclpp/mscclpp:base-dev-${{ matrix.version }} + + strategy: + fail-fast: false + matrix: + version: [ 'cuda11.8', 'cuda12.2' ] + + steps: + - uses: actions/checkout@v4 + - name: Install mscclpp + run: | + CMAKE_ARGS="-DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON" pip3 install . + + - name: Copy test script/config to temp directory + run: | + cp python/test/test_generate_mscclpp_lang_result.py $RUNNER_TEMP/ + cp python/test/configs/mscclpp_lang_test_config.json $RUNNER_TEMP/ + - name: generate outputs + run: | + python3 $RUNNER_TEMP/test_generate_mscclpp_lang_result.py python/examples/ $RUNNER_TEMP/mscclpp_lang_test_config.json $RUNNER_TEMP/tests/pr-outputs/ + - name: Checkout main branch + uses: actions/checkout@v4 + if: github.event_name == 'pull_request' || github.event_name == 'push' + with: + ref: main + - name: Install msccl and dependencies + run: | + CMAKE_ARGS="-DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON" pip3 install . + - name: generate outputs + run: | + python3 $RUNNER_TEMP/test_generate_mscclpp_lang_result.py python/examples/ $RUNNER_TEMP/mscclpp_lang_test_config.json $RUNNER_TEMP/tests/main-outputs/ + - name: Compare outputs + run: | + diff -rw $RUNNER_TEMP/tests/main-outputs/ $RUNNER_TEMP/tests/pr-outputs/ \ No newline at end of file diff --git a/README.md b/README.md index 917a437db..4127f8b8e 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [![Latest Release](https://img.shields.io/github/release/microsoft/mscclpp.svg)](https://github.com/microsoft/mscclpp/releases/latest) [![License](https://img.shields.io/github/license/microsoft/mscclpp.svg)](LICENSE) [![CodeQL](https://github.com/microsoft/mscclpp/actions/workflows/codeql-analysis.yml/badge.svg?branch=main)](https://github.com/microsoft/mscclpp/actions/workflows/codeql-analysis.yml) +[![Docs Build](https://github.com/microsoft/mscclpp/actions/workflows/doc-build.yaml/badge.svg)](https://microsoft.github.io/mscclpp/) | Pipelines | Build Status | |--------------------------|-------------------| @@ -12,7 +13,7 @@ A GPU-driven communication stack for scalable AI applications. -See [Quick Start](docs/getting-started/quickstart.md) to quickly get started. +See [Quick Start](https://microsoft.github.io/mscclpp/getting-started/quickstart.html) to quickly get started. ## Overview diff --git a/docs/design/mscclpp-dsl.md b/docs/design/mscclpp-dsl.md new file mode 100644 index 000000000..9b34b29f0 --- /dev/null +++ b/docs/design/mscclpp-dsl.md @@ -0,0 +1,114 @@ +# MSCCL++ DSL +## MSCCLPPLang Introduction +MSCCLPPLang is a Python moudule for writing high-performance commnunication algorithms. It is designed to be easy to use and efficient, while providing a high-level interface for writing communication algorithms. MSCCLPPLang program will be compiled to json based execution plan, which can be executed by MSCCL++ executor. + +## How to use MSCCLPPLang +### Install mscclpp package +```bash +git clone https://github.com/microsoft/mscclpp.git +cd mscclpp +pip install . +``` + +### Import mscclpp language module +```python +import mscclpp.language * +from mscclpp.language.types import ChannelType, ReplicationPolicy +from mscclpp.language.collectives import AllGather + +instances = 1 +size = gpus +collective = AllGather(size, chunk_factor=1, inplace=True) +with MSCCLPPProgram( + "allgather", + collective, + size, + instances, + protocol="Simple", + replication_policy=ReplicationPolicy.interleaved, +): + pass +``` + +## How MSCCLPPLang Works +MSCCLPPLang provides a high-level interface for writing communication algorithms. We treat the communication algorithm as a graph, where the nodes are the data and the edges are the communication operations. The graph is represented as a Python program, which is compiled to a json based execution plan. + +### Core Concepts + +#### MSCCLPPProgram +A MSCCLPPProgram provides the context to write MSCCLPPLang program, which can be initialized with `with` statement in Python. Its parameters include: + +- `name`: Name of this program. +- `collective`: Collective type of this program, should be from `mscclpp.language.collectives`. +- `instances`: Number of parallel instances of this program. Please see the [Instance](#instance) section for more details. +- `protocol`: Data transmission protocol used in this program, can be `LL` or `Simple`. Optional, default is `Simple`. +- `instr_fusion`: Whether low-level instruction fusion is enabled. Optional, default is `True`. +- `replication_policy`: Data replication policy, should be from `mscclpp.language.types.ReplicationPolicy`. Optional, default is `duplicated`. Please see the [Instance](#instance) section for more details. +- `num_threads_per_block`: Thread block size. Optional, default is `1024`. +- `use_double_scratch_buffer`: Whether requires double scratch buffer during execution. Optional, default is `False`. + +### Collective: +A collective is a communication operation that involves multiple GPUs. We provide a set of collective operations for users to utilize. For example, the `AllGather` operation gathers data from all GPUs to all GPUs. To instantiate a collective, the user needs to specify the number of ranks, the chunk factor (how many chunks the input buffer will be split into), and whether the operation is in-place. + +#### Chunk +A chunk is a piece of data that is sent between GPUs. It is the basic unit of data in MSCCLPPLang. Chunk can be a piece of data from input buffer, output buffer or intermediate buffer. +Example of creating a chunk: +```python +c = chunk(rank, Buffer.input, index, size) +``` +- rank: the rank of the GPU that the chunk belongs to. +- buffer: the buffer that the chunk belongs to. It can be Buffer.input, Buffer.output or Buffer.scratch. +- index: the index of the chunk in the buffer. +- size: the number of unit chunks. + +Assume we split the input data in the buffer into 4 chunks. On GPU rank 0, we can retrieve the chunks from indices 0 to 2 using the following command: +```python +c = chunk(0, Buffer.input, 0, 2) +``` + +#### Operation +The operation can only be applied to the chunks. We provide a set of communications operations for the users to use. For example, the `put` operation is used to send the data from one GPU to another GPU. The `get` operation is used to receive the data from another GPU. + +***Please notice***: MSCCLPPLang only provides one-sided communication operations. The user needs to make sure that the data is ready to be sent or received before calling the communication operations. Also we provides `wait/signal` operations to synchronize the communication across GPUs. + +#### Channel +A channel is a communication channel between two GPUs. It is used to send and receive data between GPUs. We supports three types of channel: `ChannelType.sm`, `ChannelType.proxy` and `ChannelType.nvls`. + +`ChannelType.sm` is used for communication between GPUs on the same node. This channel uses GPU processors to transfer data. + +`ChannelType.proxy` is used for communication between GPUs, whether they are on different nodes or the same node. This channel will offload the data transfer to CPU processors, which can provide better throughput compared to `ChannelType.sm`. However, this comes at the cost of higher latency compared to `ChannelType.sm`. + +`ChannelType.nvls` is used for communication between GPUs on the same node. This feature offloads the data processing task to the switch, requiring specific hardware support. Refer [nvdia documentation](https://www.nvidia.com/en-us/data-center/nvlink/) for more details. + +#### Thread Block +We can assign operations to a thread block. The thread block is a group of threads that are executed together on the GPU. In the operation function, we can specify the thread block that the operation belongs to via `sendtb` or `recvtb` parameter. + +#### Instance +An instance is a parallel execution of the program. For example, if a collective algorithm is designed to run on `n` chunks with `m` thread blocks, setting the instance to 2 will run the algorithm on `2n` chunks with `2m` thread blocks. Serveral replication policies are supported, including `duplicated` and `interleaved`. +- `duplicated`: Each chunk is split into smaller parts based on the number of instances, duplicating the same instructions for all parts. For example, ChunkA is split into ChunkA0 and ChunkA1, while ChunkB is split into ChunkB0 and ChunkB1. Both ChunkA0 and ChunkA1 belong to Instance 0, and both ChunkB0 and ChunkB1 belong to Instance 1. +- `interleaved`: Assign chunks to instances in an interleaved manner. For example, ChunkA and ChunkB are split into to ChunkA0, ChunkA1, ChunkB0, and ChunkB1. ChunkA0 and ChunkB0 belong to Instance 0, while ChunkA1 and ChunkB1 belong to Instance 1. + +#### Instruction Fusion +MSCCLPPLang provides the instruction fusion mechanism to fuse multiple operations into a single kernel. This can reduce the overhead of launching multiple instructions. When users create the MSCCLPPLang program, they can specify the `instr_fusion` parameter to enable the instruction fusion. By default, the instruction fusion is enabled. + +## MSCCLPPLang APIs + +### Basic APIs +- `chunk(rank, buffer, index, size)`: create a chunk. +- `put(self, dst, buffer, index, sendtb, chan_type)`: send the data from one GPU to another GPU. User can specify the index of the chunk in the destination buffer, the sendtb and the channel type. +- `get(self, src, buffer, index, recvtb, chan_type)`: receive the data from another GPU. User can specify the index of the chunk in the destination buffer, the recvtb and the channel type. +- `signal(self, dst, buffer, index, sendtb, chan_type)`: send a signal to another GPU. +- `wait(self, src, buffer, index, recvtb, chan_type)`: wait for a signal from another GPU. +- `flush(self, dst, buffer, index, sendtb, chan_type)`: flush the data in the buffer to the destination GPU. This is used to make sure the data is sent to the destination GPU. +- `copy(self, dst, buffer, index, sendtb)`: copy the data from one buffer to another buffer in the same GPU. +- `reduce(self, other_chunkref, recvtb, channel_type)`: Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref + +### Packet APIs +Packet APIs are used when user wants to use LL algorithm. The packet APIs are similar to the basic APIs, it will packet the data and flags into a packet and send the packet to the destination GPU. The destination GPU will unpack the packet and get the data and flags. So no synchronization is needed when using packet APIs. (`ChannelType.nvls` does not support packet APIs) +- `packet_put(self, dst, buffer, index, sendtb, chan_type)`: send the data from one GPU to another GPU using packet. +- `copy_packet(self, dst, buffer, index, sendtb)`: copy the data from one buffer to another buffer in the same GPU using packet. +- `reduce_packet(self, other_chunkref, recvtb)`: Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref using packet. + + +### Examples +We provide several examples demonstrating how to use the MSCCL++ DSL to write communication collective algorithms. For more details, please refer to the [examples](https://github.com/microsoft/mscclpp/tree/main/mscclpp-lang/python/examples) folder. \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index dc5604364..36068bfb1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -25,6 +25,7 @@ Design ------- - :doc:`Design ` doc for those who want to understand the internals of MSCCL++. - :doc:`NCCL over MSCCL++ ` doc for those who want to understand how to use NCCL over MSCCL++. +- :doc:`MSCCL++ DSL ` doc for those who want to understand the MSCCL++ DSL. .. toctree:: :maxdepth: 1 @@ -33,6 +34,7 @@ Design design/design design/nccl-over-mscclpp + design/mscclpp-dsl Performance --------------- diff --git a/python/examples/allgather_barrier.py b/python/examples/allgather_barrier.py new file mode 100644 index 000000000..acc0c2a2f --- /dev/null +++ b/python/examples/allgather_barrier.py @@ -0,0 +1,55 @@ +import argparse +from mscclpp.language import * +from mscclpp.language.buffer import Buffer +from mscclpp.language.collectives import AllGather +from mscclpp.language.types import ChannelType, ReplicationPolicy + + +def allgather_test(gpus, instances): + """ + Demonstrates how to use barrier in the MSCCL++ DSL with an allgather collective. + This example uses an allpairs algorithm for the allgather operation. + Steps: + 1. Each rank sends a chunk to all other ranks' output buffers and copies the chunk to its own output buffer. + 2. A barrier is called to synchronize the send and copy operations, and signal peers that the data has been sent. + 3. Wait for all the chunks from other ranks to be received. + """ + size = gpus + collective = AllGather(size, 1, False) + with MSCCLPPProgram( + "allgather_with_barrier", + collective, + size, + instances, + protocol="Simple", + replication_policy=ReplicationPolicy.interleaved, + ): + for n in range(gpus): + c = chunk(n, Buffer.input, 0, 1) + for peer in range(gpus): + if n != peer: + c.put(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.sm) + else: + c.copy(n, Buffer.output, n, sendtb=peer) + # explicit barrier + r = rank(n) + r.barrier(tb_list=list(range(gpus))) + for peer in range(gpus): + if n != peer: + c.signal(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.sm) + + for n in range(gpus): + for peer in range(gpus): + c = chunk(n, Buffer.output, peer, 1) + if n != peer: + c.wait(peer, Buffer.input, peer, recvtb=peer, chan_type=ChannelType.sm) + + Json() + Check() + + +parser = argparse.ArgumentParser() +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") +args = parser.parse_args() +allgather_test(args.num_gpus, args.instances) diff --git a/python/examples/allreduce_allpairs.py b/python/examples/allreduce_allpairs.py new file mode 100644 index 000000000..100e00524 --- /dev/null +++ b/python/examples/allreduce_allpairs.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from mscclpp.language import * +from mscclpp.language.collectives import AllReduce +from mscclpp.language.buffer import Buffer + + +def allreduce_allpairs(gpus, instances, protocol): + """ + Demonstrate allreduce with all pairs algorithm using put semantics. + Steps: + 1. Sync all ranks to ensure the data is ready. + 2. Each rank reads chunks from all peers and reduces the data. + 3. Put the reduced data to all peers. + 4. Sync all ranks to ensure the data is received. + """ + size = gpus + chunksperloop = gpus * gpus + collective = AllReduce(size, chunksperloop, True) + with MSCCLPPProgram("allreduce_pairs", collective, size, instances, protocol=protocol): + for rank in range(size): + for tb in range(size): + index = rank * size + c = chunk(rank, Buffer.input, index + tb) + # step1 make sure the data is ready + for nghr in range(size): + peer_index = nghr * size + if rank != nghr: + # signal peer the buffer is ready + c_peer = chunk(rank, Buffer.input, peer_index + tb) + c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb) + for nghr in range(size): + if rank != nghr: + c.wait(nghr, Buffer.input, index + tb, recvtb=tb) + # step2 reduce the chunks and send to peers + for nghr in range(size): + if rank != nghr: + c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb) + for nghr in range(size): + if rank != nghr: + c.put(nghr, Buffer.input, index + tb, sendtb=tb) + # step3 signal the peers buffer is ready + for nghr in range(size): + if rank != nghr: + c.signal(nghr, Buffer.input, index + tb, sendtb=tb) + for nghr in range(size): + if rank != nghr: + peer_index = nghr * size + c_peer = chunk(rank, Buffer.input, peer_index + tb) + c_peer.wait(nghr, Buffer.input, peer_index + tb, recvtb=tb) + + Json() + Check() + + +parser = argparse.ArgumentParser() +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") +parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol") + +args = parser.parse_args() + +allreduce_allpairs(args.num_gpus, args.instances, args.protocol) diff --git a/python/examples/allreduce_allpairs_get.py b/python/examples/allreduce_allpairs_get.py new file mode 100644 index 000000000..39e68792a --- /dev/null +++ b/python/examples/allreduce_allpairs_get.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from mscclpp.language import * +from mscclpp.language.collectives import AllReduce +from mscclpp.language.buffer import Buffer + + +def allreduce_allpairs(gpus, instances): + """ + AllReduce with all pairs algorithm using get semantics. + Steps: + 1. Sync all ranks to ensure the data is ready. + 2. Each rank read chunks from all peers and reduces the data. + 3. Signal all ranks to notify that the data is ready. + 4. Wait for all chunks to be ready, then retrieve the chunks from all peers. + """ + size = gpus + chunksperloop = gpus * gpus + collective = AllReduce(size, chunksperloop, True) + with MSCCLPPProgram( + "allreduce_pairs", + collective, + size, + instances, + protocol="Simple", + ): + + # Each rank sends the nth chunk to the nth rank into scratch space + for rank in range(size): + for tb in range(size): + index = rank * size + c = chunk(rank, Buffer.input, index + tb) + # make sure the data is ready + for nghr in range(size): + peer_index = nghr * size + if rank != nghr: + c_peer = chunk(rank, Buffer.input, peer_index + tb) + c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb) + for nghr in range(size): + if rank != nghr: + c.wait(nghr, Buffer.input, index + tb, recvtb=tb) + # reduce the chunks + for i in range(size): + nghr = (rank + i) % size + if rank != nghr: + c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb) + for nghr in range(size): + if rank != nghr: + c.signal(nghr, Buffer.input, index + tb, sendtb=tb) + + # wait for all the chunks is ready, then get the chunks + for rank in range(size): + for tb in range(size): + for nghr in range(size): + if rank != nghr: + index = nghr * size + c = chunk(rank, Buffer.input, index + tb) + c.wait(nghr, Buffer.input, index + tb, recvtb=tb) + for i in range(size): + nghr = (rank + i) % size + index = nghr * size + if rank != nghr: + c = chunk(rank, Buffer.input, index + tb) + c.get(nghr, Buffer.input, index + tb, recvtb=tb) + + Json() + Check() + + +parser = argparse.ArgumentParser() +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") + +args = parser.parse_args() + +allreduce_allpairs(args.num_gpus, args.instances) diff --git a/python/examples/allreduce_allpairs_packet.py b/python/examples/allreduce_allpairs_packet.py new file mode 100644 index 000000000..db35565b0 --- /dev/null +++ b/python/examples/allreduce_allpairs_packet.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from mscclpp.language import * +from mscclpp.language.collectives import AllReduce +from mscclpp.language.buffer import Buffer + + +def allreduce_allpairs(gpus, instances): + """ + AllReduce with all pairs algorithm using packets format. + Steps: + 1. Each rank sends its nth chunk to the nth rank's scratch space. + 2. Each rank performs a local reduction on its nth chunk using data from all other ranks' scratch spaces. + 3. Each rank sends the reduced data to all other ranks' scratch spaces. + 4. Each rank retrieves the final reduced result from the scratch space. + """ + size = gpus + chunksperloop = gpus * gpus + collective = AllReduce(size, chunksperloop, True) + with MSCCLPPProgram( + "allreduce_packets", + collective, + size, + instances, + protocol="LL", + use_double_scratch_buffer=True, + ): + # Each rank sends the nth chunk to the nth rank into scratch space + for r1 in range(size): + for tb in range(size): + if tb == r1: + continue + remote_rank = tb + index = remote_rank * size + c = chunk(r1, Buffer.input, index, size) + c.put_packet(remote_rank, "scratch", index=r1 * size, sendtb=tb) + + # Each rank performs a local reduction on the nth chunk + # Utilize 8 threadblocks for this reduction for better parallelism + for r in range(size): + for index in range(size): + c = chunk(r, Buffer.input, r * size + index) + for peer in range(size): + if peer != r: + c.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index) + for peer in range(size): + if peer != r: + c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index) + + # Each rank get final result from scratch space + for r in range(size): + for peer in range(size): + if peer != r: + c = chunk(r, "scratch", size * size + peer * size, size) + c.copy_packet(r, Buffer.input, peer * size, sendtb=peer) + + Json() + Check() + + +parser = argparse.ArgumentParser() +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") + +args = parser.parse_args() + +allreduce_allpairs(args.num_gpus, args.instances) diff --git a/python/examples/allreduce_nvls.py b/python/examples/allreduce_nvls.py new file mode 100644 index 000000000..7f09bbc93 --- /dev/null +++ b/python/examples/allreduce_nvls.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from mscclpp.language import * +from mscclpp.language.collectives import AllReduce +from mscclpp.language.buffer import Buffer + + +def allreduce_nvls(gpus, instances): + """ + Allreduce via NVLS channel + Steps: + 1. Sync all the ranks to make sure the data is ready. + 2. Call group_load_reduce to reduce the data. + 3. Call group_store to propagate the data to all the ranks. + """ + size = gpus + chunksperloop = gpus + collective = AllReduce(size, chunksperloop, True) + with MSCCLPPProgram( + "allreduce_nvls", + collective, + size, + instances, + ): + # Each rank sends the nth chunk to the nth rank into scratch space + for rank in range(size): + index = rank + c = chunk(rank, Buffer.input, index) + reduce_chunks = [] + # make sure the data is ready + for nghr in range(size): + if rank != nghr: + c_peer = chunk(nghr, Buffer.input, index) + reduce_chunks.append(c_peer) + c.signal(nghr, Buffer.input, index, sendtb=0) + for nghr in range(size): + if rank != nghr: + c.wait(nghr, Buffer.input, index, recvtb=0) + c = c.group_load_reduce(reduce_chunks, recvtb=0) + ngbrs = [nghr for nghr in range(size) if nghr != rank] + c.group_store(ngbrs, sendtb=0) + + Json() + Check() + + +parser = argparse.ArgumentParser() +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") + +args = parser.parse_args() + +allreduce_nvls(args.num_gpus, args.instances) diff --git a/python/examples/allreduce_ring.py b/python/examples/allreduce_ring.py new file mode 100644 index 000000000..15ae24e6a --- /dev/null +++ b/python/examples/allreduce_ring.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from mscclpp.language import * +from mscclpp.language.collectives import AllReduce +from mscclpp.language.buffer import Buffer + + +def allreduce_ring(size, instances): + """ + Implements a ring based allreduce. + Steps: + 1. Send signal to next rank and wait for signal from previous rank. Make sure the data is ready in previous rank. + 2. Reduce the data and send to next rank. + 3. After all the data is reduced, propagate the data to all the ranks. + """ + collective = AllReduce(size, size, True) + with MSCCLPPProgram( + f"allreduce_ring", + collective, + size, + instances, + protocol="Simple", + ): + # Reduce ring + for step in range(0, size - 1): + for index in range(0, size): + rank = (index + step) % size + next_rank = (index + step + 1) % size + c = chunk(rank, Buffer.input, index) + c.signal(next_rank, Buffer.input, index, 0) + prev_rank = (index + step - 1) % size + c = chunk(rank, Buffer.input, (index + size - 1) % size) + c.wait(prev_rank, Buffer.input, (index + size - 1) % size, 0) + c.reduce(chunk(prev_rank, Buffer.input, (index + size - 1) % size), recvtb=0) + + # Propagate ring + for step in range(-1, size - 2): + for index in range(0, size): + rank = (index + step) % size + c = chunk(rank, Buffer.input, index) + next_rank = (index + step + 1) % size + c.put(next_rank, Buffer.input, index, sendtb=0) + c.signal(next_rank, Buffer.input, index, 0) + prev_rank = (index + step - 1) % size + c = chunk(rank, Buffer.input, (index + size - 1) % size) + c.wait(prev_rank, Buffer.input, (index + size - 1) % size, 0) + + Json() + Check() + + +parser = argparse.ArgumentParser() +parser.add_argument("num_gpus", type=int, help="number of gpus") +parser.add_argument("instances", type=int, help="number of instances") +args = parser.parse_args() + +allreduce_ring(args.num_gpus, args.instances) diff --git a/python/examples/send_recv_packet.py b/python/examples/send_recv_packet.py new file mode 100644 index 000000000..f0272344e --- /dev/null +++ b/python/examples/send_recv_packet.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from mscclpp.language import * +from mscclpp.language.collectives import SendRecv +from mscclpp.language.buffer import Buffer +from mscclpp.language.types import ChannelType + + +def send_recv(instances): + """ + Send and receive data between two ranks using proxy channels, with LL protocol and double scratch buffer. + Steps: + 1. Each rank sends a chunk to every other rank's scratch buffer with packet format via proxy channel. + 2. Wait for the data to be received, then copy it to the output buffer. + """ + size = 2 + chunksperloop = 1 + collective = SendRecv(size, chunksperloop, False) + with MSCCLPPProgram( + "send_recv", + collective, + size, + instances, + protocol="LL", + use_double_scratch_buffer=True, + ): + for r in range(size): + for nghr in range(size): + if nghr == r: + continue + c = chunk(r, Buffer.input, 0) + c.put_packet( + nghr, + "scratch", + 1, + sendtb=0, + chan_type=ChannelType.proxy, + temp_buffer="scratch", + temp_buffer_index=0, + ) + + for r in range(size): + c = chunk(r, "scratch", 1) + c.copy_packet(r, Buffer.output, 0, sendtb=0) + + Json() + Check() + + +parser = argparse.ArgumentParser() +parser.add_argument("instances", type=int, help="number of instances") + +args = parser.parse_args() + +send_recv(args.instances) diff --git a/python/examples/send_recv_proxy.py b/python/examples/send_recv_proxy.py new file mode 100644 index 000000000..ec6baee99 --- /dev/null +++ b/python/examples/send_recv_proxy.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from mscclpp.language import * +from mscclpp.language.buffer import Buffer +from mscclpp.language.collectives import SendRecv +from mscclpp.language.types import ChannelType + + +def send_recv(instances): + """ + Send and receive data between two ranks using proxy channels. + steps: + 1. Each rank sends a chunk to the other rank's scratch buffer and signals the other rank that the data has been sent. + 2. Wait for the data to be received then copy it to the output buffer. + """ + size = 2 + chunksperloop = 1 + collective = SendRecv(size, chunksperloop, False) + with MSCCLPPProgram( + "send_recv", + collective, + size, + instances, + ): + for r in range(size): + for nghr in range(size): + if nghr == r: + continue + c = chunk(r, Buffer.input, 0) + c.put( + nghr, + "scratch", + 1, + sendtb=0, + chan_type=ChannelType.proxy, + ) + c.signal(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy) + c.flush(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy) + + for r in range(size): + c = chunk(r, "scratch", 1) + c.wait(1 - r, Buffer.input, 0, recvtb=0, chan_type=ChannelType.proxy) + c.copy(r, Buffer.output, 0, sendtb=0) + + Json() + Check() + + +parser = argparse.ArgumentParser() +parser.add_argument("instances", type=int, help="number of instances") + +args = parser.parse_args() + +send_recv(args.instances) diff --git a/python/mscclpp/language/__init__.py b/python/mscclpp/language/__init__.py new file mode 100644 index 000000000..3616eaa59 --- /dev/null +++ b/python/mscclpp/language/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from mscclpp.language.program import MSCCLPPProgram, Json, Check, chunk, rank diff --git a/python/mscclpp/language/buffer.py b/python/mscclpp/language/buffer.py new file mode 100644 index 000000000..cc2f01c34 --- /dev/null +++ b/python/mscclpp/language/buffer.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from enum import Enum + + +# Scratch buffer slice with manual indexing +class BufferSlice: + def __init__(self, buf, name): + self.name = name + self.buf = buf + self.offset = -1 # Offset into the global scratch buffer + self.chunks = [] + + # Returns the global index into the scratch buffer + def get_global_index(self, index): + assert self.offset > -1, "set_offset needs to be called first" + return self.offset + index + + def get_buffer(self): + return self.buf + + def instance_size(self): + return len(self.chunks) + + def set_offset(self, offset): + self.offset = offset + + def __getitem__(self, index): + return self.chunks[index] + + def __setitem__(self, index, value): + current_size = len(self.chunks) + while index > current_size: + self.chunks.append(None) + current_size = len(self.chunks) + if index == current_size: + self.chunks.append(value) + else: + self.chunks[index] = value + + def __len__(self): + return len(self.chunks) + + +class Buffer(Enum): + input = "i" + output = "o" + scratch = "s" + + def __str__(self): + return self.value + + def __lt__(self, other): + return self.value < other.value + + def __gt__(self, other): + return self.value < other.value diff --git a/python/mscclpp/language/chunk.py b/python/mscclpp/language/chunk.py new file mode 100644 index 000000000..908ef38a4 --- /dev/null +++ b/python/mscclpp/language/chunk.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from dataclasses import dataclass + + +@dataclass +class Chunk: + origin_rank: int # Rank the chunk initially started at + origin_index: int # Index the chunk initially started at + dst_rank: int = -1 + dst_index: int = -1 + + def reduce(self, dst, chunk): + if type(chunk) is ReduceChunk: + return chunk.reduce(dst, self) + elif type(chunk) is Chunk: + chunks = [self, chunk] + return ReduceChunk(dst, chunks) + else: + raise ValueError("Trying to reduce with chunk of None") + + def __hash__(self): + return hash((self.origin_rank, self.origin_index)) + + def __eq__(self, other): + return ( + type(other) is Chunk and self.origin_rank == other.origin_rank and self.origin_index == other.origin_index + ) + + def __lt__(self, other): + return self.origin_rank < other.origin_rank or ( + self.origin_rank == other.origin_rank and self.origin_index < other.origin_index + ) + + +@dataclass +class ReduceChunk: + creation_rank: int # Rank the Reduce Chunk is created. Necessary since the same ReduceChunk can be created on multiple ranks independently + chunks: list # List of chunks reduced + + def reduce(self, dst, chunk): + if type(chunk) is ReduceChunk: + chunks = self.chunks + chunk.chunks + elif type(chunk) is Chunk: + chunks = self.chunks + [chunk] + else: + raise ValueError("Trying to reduce with chunk of None") + return ReduceChunk(self.creation_rank, chunks) + + def sort(self): + self.chunks.sort() + + def __hash__(self): + self.sort() + return hash((self.creation_rank,) + tuple(self.chunks)) + + # Two reduce chunks are equal if they contain the same list of + # chunks being reduced + def __eq__(self, other): + self.sort() + other.sort() + return self.chunks == other.chunks diff --git a/python/mscclpp/language/collectives.py b/python/mscclpp/language/collectives.py new file mode 100644 index 000000000..67b735ba9 --- /dev/null +++ b/python/mscclpp/language/collectives.py @@ -0,0 +1,339 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from mscclpp.language.buffer import Buffer +from mscclpp.language.chunk import Chunk, ReduceChunk + + +class Collective: + + def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs): + self.num_ranks = num_ranks + self.chunk_factor = chunk_factor + self.inplace = inplace + self.name = "custom" + # Divide the buffer into num_chunk_groups group + if num_ranks_per_node == -1: + self.num_ranks_per_node = num_ranks + else: + self.num_ranks_per_node = num_ranks_per_node + + # kwargs + # Number of chunk groups: which means we will group n chunks into m groups. + # We will guarantee that the group size is the same. + # But in the same group, the chunk size may be different due to group size + # can not be divided by the number of chunks. + self.num_chunk_groups = kwargs.get("num_chunk_groups", 1) + + def init_buffers(self): + pass + + def check(self, prog): + pass + + def get_buffer_index(self, rank, buffer, index): + return buffer, index + + +class AllToAll(Collective): + + def __init__(self, num_ranks, chunk_factor, inplace): + Collective.__init__(self, num_ranks, chunk_factor, inplace) + self.name = "alltoall" + + def init_buffers(self): + chunks_per_node = self.num_ranks * self.chunk_factor + rank_buffers = [] + for r in range(self.num_ranks): + input_buffer = [None] * chunks_per_node + output_buffer = [None] * chunks_per_node + for index in range(chunks_per_node): + chunk = Chunk(r, index, index // self.chunk_factor, index % self.chunk_factor + r * self.chunk_factor) + input_buffer[index] = chunk + if self.inplace: + buffers = {Buffer.input: input_buffer, Buffer.output: input_buffer} + else: + buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer} + rank_buffers.append(buffers) + return rank_buffers + + # Expected output buffer for alltoall + def check(self, prog): + chunks_per_node = self.num_ranks * self.chunk_factor + correct = True + for r in range(self.num_ranks): + output = prog.buffers[r][Buffer.output] + for i in range(self.num_ranks): + for ch in range(self.chunk_factor): + index = ch + i * self.chunk_factor + chunk = output[index] + expected_origin_index = ch + r * self.chunk_factor + if chunk is None or chunk.origin_rank != i or chunk.origin_index != expected_origin_index: + print( + f"Rank {r} chunk {index} is incorrect should be chunk({i},{expected_origin_index}) given {chunk}" + ) + correct = False + return correct + + +class AllGather(Collective): + def __init__(self, num_ranks, chunk_factor, inplace): + Collective.__init__(self, num_ranks, chunk_factor, inplace) + self.name = "allgather" + + # Initializes input buffer for an allgather + def init_buffers(self): + rank_buffers = [] + if self.inplace: + # Inplace AllGather only uses the output buffer + for r in range(self.num_ranks): + output_buffer = [None] * (self.num_ranks * self.chunk_factor) + for rank in range(self.num_ranks): + for ch in range(self.chunk_factor): + output_buffer[rank * self.chunk_factor + ch] = Chunk( + rank, ch, -1, rank * self.chunk_factor + ch + ) + buffers = { + Buffer.input: output_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor], + Buffer.output: output_buffer, + } + rank_buffers.append(buffers) + else: + for r in range(self.num_ranks): + input_buffer = [None] * self.chunk_factor + output_buffer = [None] * (self.num_ranks * self.chunk_factor) + for ch in range(self.chunk_factor): + input_buffer[ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch) + buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer} + rank_buffers.append(buffers) + return rank_buffers + + # Expected output buffer for allgather + def check(self, prog): + correct = True + buf = Buffer.output + for r in range(self.num_ranks): + output = prog.buffers[r][buf] + for i in range(self.num_ranks): + for ch in range(self.chunk_factor): + index = i * self.chunk_factor + ch + chunk = output[index] + if chunk is None: + print(f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given None") + correct = False + elif chunk.origin_rank != i or chunk.origin_index != ch: + print( + f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})" + ) + correct = False + return correct + + def get_buffer_index(self, rank, buffer, index): + # For inplace AllGathers, the input buffer points into the output buffer + if self.inplace and buffer == Buffer.input: + return Buffer.output, index + rank * self.chunk_factor + else: + return buffer, index + + +class AllReduce(Collective): + + def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs): + num_chunk_groups = kwargs.get("num_chunk_groups", num_ranks) + Collective.__init__( + self, num_ranks, chunk_factor, inplace, num_ranks_per_node, num_chunk_groups=num_chunk_groups + ) + self.name = "allreduce" + + def init_buffers(self): + chunks_per_node = self.chunk_factor + rank_buffers = [] + for r in range(self.num_ranks): + input_buffer = [] + output_buffer = [None] * chunks_per_node + for c in range(chunks_per_node): + # Chunks start at rank r index c, and ends on all ranks (-1) at index r + input_buffer.append(Chunk(r, c, -1, c)) + # Input and output buffer are the same. + if self.inplace: + buffers = {Buffer.input: input_buffer, Buffer.output: input_buffer} + else: + buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer} + rank_buffers.append(buffers) + return rank_buffers + + def check(self, prog): + chunks_per_node = self.chunk_factor + expected_chunks = [] + buf = Buffer.input if self.inplace else Buffer.output + + for c in range(chunks_per_node): + chunk = ReduceChunk(-1, []) + for r in range(self.num_ranks): + chunk = chunk.reduce(-1, Chunk(r, c)) + expected_chunks.append(chunk) + + correct = True + for r in range(self.num_ranks): + output = prog.buffers[r][buf] + for c in range(chunks_per_node): + chunk = output[c] + if chunk is None or chunk != expected_chunks[c]: + print( + f"Rank {r} chunk {c} is incorrect should be ReduceChunk index {c} from all ranks, given {chunk}" + ) + correct = False + return correct + + def get_buffer_index(self, rank, buffer, index): + if self.inplace and buffer == Buffer.output: + return Buffer.input, index + else: + return buffer, index + + +class ReduceScatter(Collective): + def __init__(self, num_ranks, chunk_factor, inplace): + Collective.__init__(self, num_ranks, chunk_factor, inplace) + self.name = "reducescatter" + + def init_buffers(self): + rank_buffers = [] + for r in range(self.num_ranks): + if self.inplace: + input_buffer = [] + for i in range(self.num_ranks): + for c in range(self.chunk_factor): + input_buffer.append(Chunk(r, i * self.chunk_factor + c, i, c)) + buffers = {Buffer.input: input_buffer} + rank_buffers.append(buffers) + else: + input_buffer = [] + output_buffer = [None] * self.chunk_factor + for i in range(self.num_ranks): + for c in range(self.chunk_factor): + input_buffer.append(Chunk(r, i * self.chunk_factor + c, i, c)) + buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer} + rank_buffers.append(buffers) + return rank_buffers + + def check(self, prog): + expected_chunks = [] + buf = Buffer.input if self.inplace else Buffer.output + for c in range(self.num_ranks * self.chunk_factor): + chunk = ReduceChunk(-1, []) + for r in range(self.num_ranks): + chunk = chunk.reduce(-1, Chunk(r, c)) + expected_chunks.append(chunk) + + correct = True + for r in range(self.num_ranks): + output = prog.buffers[r][buf] + for c in range(self.chunk_factor): + correct_idx = r * self.chunk_factor + c + if self.inplace: + c = correct_idx + chunk = output[c] + if chunk is None or chunk != expected_chunks[correct_idx]: + print(f"Rank {r} chunk {c} is incorrect should be index {correct_idx} from all ranks given {chunk}") + correct = False + return correct + + def get_buffer_index(self, rank, buffer, index): + # For inplace ReduceScatter the output buffer is a pointer into the input buffer + if self.inplace and buffer == Buffer.output: + return Buffer.input, index + rank * self.chunk_factor + else: + return buffer, index + + +# SendRecv is a collective that sends a chunk from one rank to another +# It is used to test the correctness of the send and receive instructions +class SendRecv(Collective): + def __init__(self, num_ranks, chunk_factor, inplace): + assert num_ranks == 2, "SendRecv only supports 2 ranks" + Collective.__init__(self, num_ranks, chunk_factor, inplace) + self.name = "sendrecv" + + def init_buffers(self): + rank_buffers = [] + for r in range(self.num_ranks): + input_buffer = [None] * self.chunk_factor + output_buffer = [None] * self.chunk_factor + for c in range(self.chunk_factor): + input_buffer[c] = Chunk(r, c, -1, c) + buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer} + rank_buffers.append(buffers) + return rank_buffers + + def check(self, prog): + correct = True + buff_type = Buffer.input if self.inplace else Buffer.output + for r in range(self.num_ranks): + output = prog.buffers[r][buff_type] + for c in range(self.chunk_factor): + chunk = output[c] + if chunk is None or chunk.origin_rank != 1 - r or chunk.origin_index != c: + print(f"Rank {r} chunk {c} is incorrect should be ({1 - r}, {c}) given {chunk}") + correct = False + + return correct + + def get_buffer_index(self, rank, buffer, index): + if self.inplace and buffer == Buffer.output: + return Buffer.input, index + return buffer, index + + +class Broadcast(Collective): + def __init__(self, num_ranks, chunk_factor, inplace, root): + Collective.__init__(self, num_ranks, chunk_factor, inplace, root) + self.name = "broadcast" + self.root = root + + # Initializes input buffer for an broadcast + def init_buffers(self): + rank_buffers = [] + if self.inplace: + # Inplace broadcast only uses the input buffer + for r in range(self.num_ranks): + input_buffer = [None] * (self.chunk_factor) + for ch in range(self.chunk_factor): + input_buffer[ch] = Chunk(self.root, ch, -1, ch) + buffers = { + Buffer.input: input_buffer, + Buffer.output: input_buffer, + } + rank_buffers.append(buffers) + else: + for r in range(self.num_ranks): + input_buffer = [None] * self.chunk_factor + output_buffer = [None] * self.chunk_factor + if r == self.root: + for ch in range(self.chunk_factor): + input_buffer[ch] = Chunk(self.root, ch, -1, ch) + buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer} + rank_buffers.append(buffers) + return rank_buffers + + # Expected output buffer for broadcast + def check(self, prog): + correct = True + buf = Buffer.output + for r in range(self.num_ranks): + output = prog.buffers[0][buf] + for ch in range(self.chunk_factor): + index = ch + chunk = output[index] + if chunk is None: + print(f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given None") + correct = False + elif chunk.origin_rank != self.root or chunk.origin_index != ch: + print( + f"Rank {r} chunk {index} is incorrect should be ({self.root}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})" + ) + correct = False + return correct + + def get_buffer_index(self, rank, buffer, index): + return buffer, index diff --git a/python/mscclpp/language/dag/__init__.py b/python/mscclpp/language/dag/__init__.py new file mode 100644 index 000000000..616e396f1 --- /dev/null +++ b/python/mscclpp/language/dag/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from mscclpp.language.dag.instruction_dag import InstructionDAG +from mscclpp.language.dag.lower import DagLower +from mscclpp.language.dag.optimizer import DagOptimizer diff --git a/python/mscclpp/language/dag/instruction_dag.py b/python/mscclpp/language/dag/instruction_dag.py new file mode 100644 index 000000000..dcc1189ca --- /dev/null +++ b/python/mscclpp/language/dag/instruction_dag.py @@ -0,0 +1,373 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from collections import defaultdict +from mscclpp.language.buffer import Buffer +from mscclpp.language.types import ( + Channel, + ChannelType, + ChunkRef, + Instruction, + Op, +) + + +class InstructionDAG: + def __init__(self, num_ranks: int, buffers: list): + self.num_ranks = num_ranks + self.buffers = buffers + # State for the actual instruction DAG + self.operations = {} # slot -> operations + self.last_writer = {} # slot -> last writing op + self.last_readers = defaultdict(list) # slot -> list of last reading ops + # State for the MSCCLPP-IR + self.tbs = [] + for _ in range(num_ranks): + self.tbs.append({}) + self.tb_mapping = {} + self.num_channels = [1] * num_ranks + self.tb_steps = [{} for _ in range(num_ranks)] + + def convert_set_list(self): + ops = [] + visited = set() + for slot, op in self.operations.items(): + if op.inst == Instruction.start: + op.next = list(op.next) + for o in op.next: + ops.append(o) + elif op.inst != Instruction.copy: + ops.append(op) + + while len(ops) > 0: + op = ops[0] + if op not in visited: + visited.add(op) + op.next = list(op.next) + ops = ops[1:] + op.next + else: + ops = ops[1:] + return visited + + def complete_channels(self): + send_op = [Instruction.put, Instruction.signal, Instruction.put_packet] + recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy] + group_send_op = [Instruction.group_store] + group_recv_op = [Instruction.group_load_reduce] + for rank, rank_tbs in enumerate(self.tbs): + for tbid, tb in rank_tbs.items(): + chans = set() + for op in tb.ops: + if op.inst == Instruction.barrier: + continue + if op.src != None: + src_buffer = ( + Buffer.scratch + if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output + else op.src.buffer + ) + if op.dst != None: + dst_buffer = ( + Buffer.scratch + if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output + else op.dst.buffer + ) + if op.channel_type == ChannelType.nvls: + if op.inst in group_send_op: + ranks = [dst[0].rank for dst in op.dsts] + chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks) + chans.add(chan) + elif op.inst in group_recv_op: + ranks = [src[0].rank for src in op.srcs] + chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks) + chans.add(chan) + else: + if op.inst in send_op: + chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank) + chans.add(chan) + elif op.inst in recv_op: + chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank) + chans.add(chan) + tb.channels = list(chans) + + # InstructionDAG - builds the roots of the DAG + def add_start(self, rank, buffer, index, ref): + slot = (rank, buffer, index) + op = Op(Instruction.start, rank, ref, ref, next=set(), prev=set()) + self.operations[slot] = op + self.last_writer[slot] = op + + # InstructionDAG - adds a copy node + def add_copy(self, rank, send_ref, recv_ref, tb, trans_from_packet=False, trans_to_packet=False): + tb_step = self._get_tb_step(rank, tb) + if trans_from_packet: + op = Op(Instruction.copy_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + elif trans_to_packet: + op = Op( + Instruction.transform_to_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step + ) + else: + op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + dstbuffer = recv_ref.buffer + dstindex = recv_ref.index + srcbuffer = send_ref.buffer + srcindex = send_ref.index + size = recv_ref.size + # Sending part of copy [Read] + self._read(rank, srcbuffer, srcindex, size, op) + # Receiving part of copy [Write] + self._write(rank, dstbuffer, dstindex, size, op) + return op + + # InstructionDAG - adds a redduce node + def add_reduce(self, rank, send_ref, recv_ref, tb, use_packet=False): + tb_step = self._get_tb_step(rank, tb) + if use_packet: + op = Op(Instruction.reduce_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + else: + op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step) + dstbuffer = recv_ref.buffer + dstindex = recv_ref.index + srcbuffer = send_ref.buffer + srcindex = send_ref.index + size = recv_ref.size + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + # Sending part of reduce + self._read(rank, srcbuffer, srcindex, size, op) + # Reduce part of copy + self._write(rank, dstbuffer, dstindex, size, op, read=True) + return op + + # InstructionDAG - adds a put node + def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet=False): + tb_step = self._get_tb_step(rank, tb) + if use_packet: + op = Op( + Instruction.put_packet, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + else: + op = Op( + Instruction.put, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + self._read(rank, buffer, index, size, op) + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + return op + + def add_get(self, rank, send_ref, recv_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.get, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step + ) + buffer = recv_ref.buffer + index = recv_ref.index + size = recv_ref.size + self._write(rank, buffer, index, size, op) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + return op + + # InstructionDAG - adds a signal node. + def add_signal(self, rank, send_ref, recv_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.signal, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + # treat signal as a write. signal acts as a barrier for the next instruction which prevents the + # below instructions to be scheduled above the signal instruction. + self._write(rank, buffer, index, size, op) + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + return op + + def add_flush(self, rank, send_ref, recv_ref, tb): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.flush, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ChannelType.proxy, + step=tb_step, + ) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + self._read(rank, buffer, index, size, op) + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + return op + + def add_wait(self, rank, dst_ref, src_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.wait, rank, src_ref, dst_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step + ) + buffer = dst_ref.buffer + index = dst_ref.index + size = dst_ref.size + self._write(rank, buffer, index, size, op) + op.srcs.append((ChunkRef(src_ref.rank, src_ref.buffer, src_ref.index, src_ref.size), tb_step)) + op.dsts.append((ChunkRef(dst_ref.rank, dst_ref.buffer, dst_ref.index, dst_ref.size), tb_step)) + return op + + def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.read_reduce_copy, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + buffer = recv_ref.buffer + index = recv_ref.index + size = recv_ref.size + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + self._write(rank, buffer, index, size, op, read=True) + return op + + def add_barrier(self, rank, tb_list, barrier_id): + buffers = self.buffers[rank] + for tb in tb_list: + tb_step = self._get_tb_step(rank, tb) + extra = {"tb_list": tb_list, "barrier_id": barrier_id} + op = Op(Instruction.barrier, rank, None, None, next=set(), prev=set(), tb=tb, step=tb_step, extra=extra) + for buffer_type, buffer in buffers.items(): + self._write(rank, buffer_type, 0, len(buffer), op) + + def add_group_load_reduce(self, rank, send_refs, recv_ref, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.group_load_reduce, + rank, + recv_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + # treat recv_ref as src for group_load_reduce + op.srcs.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + for send_ref in send_refs: + op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + buffer = recv_ref.buffer + index = recv_ref.index + size = recv_ref.size + self._write(rank, buffer, index, size, op, read=True) + + def add_group_store(self, rank, send_ref, recv_refs, tb, ch_type): + tb_step = self._get_tb_step(rank, tb) + op = Op( + Instruction.group_store, + rank, + send_ref, + send_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + # treat send_ref as dst for group_store + op.dsts.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) + for recv_ref in recv_refs: + op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step)) + buffer = send_ref.buffer + index = send_ref.index + size = send_ref.size + self._read(rank, buffer, index, size, op) + return op + + def _get_tb_step(self, rank: int, tb: int): + if tb in self.tb_steps[rank]: + self.tb_steps[rank][tb] += 1 + return self.tb_steps[rank][tb] + else: + self.tb_steps[rank][tb] = 0 + return 0 + + # InstructionDAG helper - identifies the dependencies for a write-type operation (recv, copy, rrc, reduce) + def _write(self, rank, buffer, index, size, op, read=False): + prev_ops = set() + for i in range(index, index + size): + slot = (rank, buffer, i) + if read: + assert slot in self.last_writer, f"Destination slot has never been written before a reduce {op}" + + # First write to this slot + if slot not in self.operations: + self.operations[slot] = op + + # If there are active readers - these are the previous operations + # Else the previous operation is the last write (if there is one) + readers = self.last_readers[slot] + if len(readers) > 0: + prev_ops.update(readers) + elif slot in self.last_writer: + prev_ops.add(self.last_writer[slot]) + + # Set the last_writer to this op, and clear all readers + self.last_writer[slot] = op + self.last_readers[slot] = [] + + # Update the next pointer of the previous ops + for prev_op in prev_ops: + prev_op.next.add(op) + op.prev.add(prev_op) + + # InstructionDAG helper - identifies the dependencies for read-type operations (send, copy, reduce) + def _read(self, rank, buffer, index, size, op): + prev_ops = set() + for i in range(index, index + size): + slot = (rank, buffer, i) + assert slot in self.last_writer, f"Slot has never been written before a read-type {op}" + # The previous operation for a reader is the last write to the slot + writer = self.last_writer[slot] + prev_ops.add(writer) + self.last_readers[slot].append(op) + + # Update the next pointer of the previous ops + for prev_op in prev_ops: + prev_op.next.add(op) + op.prev.add(prev_op) diff --git a/python/mscclpp/language/dag/lower.py b/python/mscclpp/language/dag/lower.py new file mode 100644 index 000000000..0283c6f56 --- /dev/null +++ b/python/mscclpp/language/dag/lower.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import copy +from typing import List +from mscclpp.language.buffer import Buffer +from mscclpp.language.dag.instruction_dag import InstructionDAG +from mscclpp.language.types import ChunkRef, Gpu, Instruction, Op, ReplicationPolicy, Threadblock + + +class DagLower: + def __init__(self, dag: InstructionDAG): + self.dag = dag + self.instanced_tbs = [] + + def lower(self, instances: int, replication_policy: ReplicationPolicy): + self._infer_dependencies() + self._lower_buffers(instances) + self._replicate(instances, replication_policy) + return self._lower_tbs() + + def _replicate(self, instances: int, replication_policy: ReplicationPolicy): + # update op step + for rank, rank_tbs in enumerate(self.dag.tbs): + for _, tb in rank_tbs.items(): + for id, op in enumerate(tb.ops): + op.step = id + + if instances == 1: + self.instanced_tbs = self.dag.tbs + return + + self.instanced_tbs = [] + for _ in range(self.dag.num_ranks): + self.instanced_tbs.append({}) + + def get_new_index(rank, buffer, index, size, i): + if replication_policy == ReplicationPolicy.interleaved: + return index * instances + i * size + return len(self.dag.buffers[rank][buffer]) * i + index + + def get_instance_ref(ref): + if ref is None: + return None + iindex = get_new_index(ref.rank, ref.buffer, ref.index, ref.size, i) + iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size) + return iref + + def update_extra(op, ori_op): + if op.inst == Instruction.barrier: + tb_list = ori_op.extra["tb_list"] + new_tb_list = [tb * instances + i for tb in tb_list] + op.extra["tb_list"] = new_tb_list + op.extra["barrier_id"] = ori_op.extra["barrier_id"] * instances + i + + for i in range(instances): + # Generate all the threadblocks and ops + for rank, rank_tbs in enumerate(self.dag.tbs): + # rank_channels = self.num_channels[rank] + for tbid, tb in rank_tbs.items(): + itbid = tbid * instances + i + itb = Threadblock(id=itbid) + itb.ops = [None] * len(tb.ops) + for s, op in enumerate(tb.ops): + isrc = get_instance_ref(op.src) + idst = get_instance_ref(op.dst) + idepends = [] + # Note: We don't need the fill out the rest of the metadata since replication is the last optimization + iop = Op( + op.inst, + op.rank, + isrc, + idst, + idepends, + op.step, + itbid, + channel_type=op.channel_type, + extra=copy.deepcopy(op.extra), + ) + update_extra(iop, op) + itb.ops[s] = iop + for src, step in op.srcs: + isrc = get_instance_ref(src) + iop.srcs.append((isrc, step)) + for dst, step in op.dsts: + idst = get_instance_ref(dst) + iop.dsts.append((idst, step)) + for chan in tb.channels: + itb.channels.append(chan) + self.instanced_tbs[op.rank][itbid] = itb + + # Redo dependency analysis + for rank, rank_tbs in enumerate(self.dag.tbs): + for tbid, tb in rank_tbs.items(): + for i in range(instances): + itbid = tbid * instances + i + itb = self.instanced_tbs[rank][itbid] + for op, iop in zip(tb.ops, itb.ops): + iop.depends = [None] * len(op.depends) + for s, dep in enumerate(op.depends): + dep_tbid = dep.tb + dep_itbid = dep_tbid * instances + i + dep_step = dep.step + iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step] + + # Convert local scratch buffers to index into one global scratch buffer + + def _lower_chunk(self, chunk): + if chunk is not None and chunk.buffer is not Buffer.input and chunk.buffer is not Buffer.output: + buffer = self.dag.buffers[chunk.rank][chunk.buffer].get_buffer() + index = self.dag.buffers[chunk.rank][chunk.buffer].get_global_index(chunk.index) + return ChunkRef(chunk.rank, buffer, index, chunk.size) + return chunk + + # Assigns each scratch buffer an offset into the global scratch buffer + def _lower_buffers(self, instances): + for rank_buffers in self.dag.buffers: + offset = 0 + for key, buf in rank_buffers.items(): + if key is not Buffer.input and key is not Buffer.output: + buf.set_offset(offset) + offset += buf.instance_size() * instances + + def _lower_tbs(self) -> List[Gpu]: + gpus = [] + for rank, rank_tbs in enumerate(self.instanced_tbs): + lowered_tbs = {} + for tbid, tb in rank_tbs.items(): + for op in tb.ops: + op.src = self._lower_chunk(op.src) + op.dst = self._lower_chunk(op.dst) + srcs = sorted(op.srcs, key=lambda x: x[1]) + dsts = sorted(op.dsts, key=lambda x: x[1]) + op.srcs = [self._lower_chunk(src[0]) for src in srcs] + op.dsts = [self._lower_chunk(dst[0]) for dst in dsts] + lowered_tbs[tbid] = tb + gpus.append(Gpu(rank, list(lowered_tbs.values()))) + return gpus + + def _infer_dependencies(self): + visited = set() + for _, op in self.dag.operations.items(): + if op in visited: + continue + frontier = [op] + while len(frontier) > 0: + op = frontier[0] + if op in visited: + frontier = frontier[1:] + continue + # Dependencies for every op is the same as the ops that are stored in prev + # Filter out dependencies that are satisified by tbs executing ops sequentially + # If multiple dependent ops from the same tb keep the one that happens last + depends = {} + for dep_op in list(op.prev): + if dep_op.inst != Instruction.start: + tb = dep_op.tb + if tb not in depends or dep_op.step > depends[tb].step: + depends[tb] = dep_op + op.depends = list(depends.values()) + visited.add(op) + frontier = frontier[1:] + op.next diff --git a/python/mscclpp/language/dag/optimizer.py b/python/mscclpp/language/dag/optimizer.py new file mode 100644 index 000000000..62fc0f5e8 --- /dev/null +++ b/python/mscclpp/language/dag/optimizer.py @@ -0,0 +1,405 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from mscclpp.language.utils import ( + buf_dst_src_match, + circular_dep_after_merge, + merge_op, + remove_op, + same_chan_type, + same_count, + same_buf_dst, + same_buf_src, + same_src_dst_buffer_type, + same_tb, + all_prevs_visited_after_merge, +) +from mscclpp.language.dag.instruction_dag import InstructionDAG +from mscclpp.language.types import ChunkRef, ChannelType, Instruction, Op, Threadblock + + +class _InstructionOptimizer: + + def try_merge_same_instructions( + self, + op: Op, + next_op: Op, + tb: Threadblock, + queue: list, + expected_next_inst: Instruction, + same_buf_func: callable, + ) -> bool: + """ + Attempts to merge two instruction if conditions are met. + :param op: The current operation. + :param next_op: The next operation to potentially merge with. + :param tb: The thread block containing the operations. + :param queue: The queue of operations. + :param expected_next_inst: The instruction type expected for the next operation. + :param same_buf_func: The function to check if the buffer is the same (same_buf_dst or same_buf_src). + :return: True if operations are merged, False otherwise. + """ + if ( + next_op.inst == expected_next_inst + and same_tb(op, next_op) + and same_buf_func(op, next_op) + and same_count(op, next_op) + and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) + and all_prevs_visited_after_merge(op, next_op) + ): + # Append the source chunks from next_op + op.srcs.append( + ( + ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size), + next_op.step, + ) + ) + # For 'signal' and 'wait' instructions, append destination chunks too + if expected_next_inst in [Instruction.signal, Instruction.wait, Instruction.flush]: + op.dsts.append( + ( + ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), + next_op.step, + ) + ) + # Merge operations + merge_op(op, next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + return True + return False + + def try_compact_instructions( + self, op: Op, tb: Threadblock, queue: list, inst_type: Instruction, same_src_dst_func: callable + ) -> bool: + """ + Try to campact the instructions with the same instruction type. This optimization will + compact multiple instructions of the same type into a single instruction. + :param op: The current operation. + :param seq_op: The sequential operation to merge with. + :param tb: The task block containing the operations. + :param queue: The queue of operations. + :param inst_type: The type of the instruction being processed (get, put, put_packet). + :return: True if operations are merged, False otherwise. + """ + if len(queue) > 1: + seq_op = queue[1] + if ( + seq_op.inst == inst_type + and same_src_dst_func(op, seq_op) + and same_chan_type(op, seq_op) + and same_count(op, seq_op) + and not circular_dep_after_merge(op, seq_op) + and all_prevs_visited_after_merge(op, seq_op) + ): + # Append the source and destination chunks from seq_op + op.dsts.append( + ( + ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size), + seq_op.step, + ) + ) + op.srcs.append( + ( + ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size), + seq_op.step, + ) + ) + merge_op(op, seq_op) + tb.ops.remove(seq_op) + queue.remove(seq_op) + return True + return False + + def try_fuse_with_put(self, op: Op, next_op: Op, tb: Threadblock, queue: list) -> bool: + """ + Attempts to fuse 'put' operations with other operations like read_reduce_copy, reduce, etc. + :param op: The current operation. + :param next_op: The next operation to potentially merge with. + :param tb: The thread block containing the operations. + :param queue: The queue of operations. + :param inst_type: The type of the instruction being processed. + :param chan_type: Channel type if applicable. + :return: True if operations are merged, False otherwise. + """ + if ( + (next_op.inst == Instruction.put or next_op.inst == Instruction.put_packet) + and same_tb(op, next_op) + and same_count(op, next_op) + and buf_dst_src_match(op, next_op) + and next_op.channel_type == ChannelType.sm + and (op.channel_type == ChannelType.none or op.channel_type == ChannelType.sm) + and not circular_dep_after_merge(op, next_op) + and all_prevs_visited_after_merge(op, next_op) + ): + if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer: + return False + # Adjust instruction type and channel if needed + if op.inst == Instruction.read_reduce_copy: + op.inst = Instruction.read_reduce_copy_send + elif op.inst == Instruction.reduce: + op.inst = Instruction.reduce_send + op.channel_type = ChannelType.sm + elif op.inst == Instruction.reduce_packet: + op.inst = Instruction.reduce_send_packet + op.channel_type = ChannelType.sm + # Append the destination chunk from next_op + op.dsts.append( + ( + ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size), + next_op.step, + ) + ) + # Merge operations + merge_op(op, next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + return True + return False + + def try_fuse_instructions_using_proxy_channel( + self, op: Op, next_op: Op, tb: Threadblock, queue: list, expected_next_inst: Instruction + ) -> bool: + """ + Attempts to fuse operations which using proxy channel. + :param op: The current operation. + :param next_op: The next operation to potentially merge with. + :param tb: The thread block containing the operations. + :param queue: The queue of operations. + :param expected_next_inst: The instruction type expected for the next operation. + :return: True if operations are merged, False otherwise. + """ + if ( + next_op.inst == expected_next_inst + and same_tb(op, next_op) + and same_count(op, next_op) + and same_buf_dst(op, next_op) + and same_buf_src(op, next_op) + and same_chan_type(op, next_op) + and op.channel_type == ChannelType.proxy + and not circular_dep_after_merge(op, next_op) + and all_prevs_visited_after_merge(op, next_op) + ): + if op.inst == Instruction.put and next_op.inst == Instruction.signal: + op.inst = Instruction.put_with_signal + elif op.inst == Instruction.put_with_signal and next_op.inst == Instruction.flush: + op.inst = Instruction.put_with_signal_and_flush + # Merge operations + merge_op(op, next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + return True + return False + + def try_fuse_with_group_store(self, op: Op, next_op: Op, tb: Threadblock, queue: list) -> bool: + """ + Attempts to fuse 'gruop_load_reduce' operations with 'group_store' operations. + :param op: The current operation. + :param next_op: The next operation to potentially merge with. + :param tb: The thread block containing the operations. + :param queue: The queue of operations. + :return: True if operations are merged, False otherwise. + """ + if ( + next_op.inst == Instruction.group_store + and same_count(op, next_op) + and buf_dst_src_match(op, next_op) + and same_chan_type(op, next_op) + and not circular_dep_after_merge(op, next_op) + and all_prevs_visited_after_merge(op, next_op) + ): + # Append the destination chunk from next_op + op.inst = Instruction.group_load_reduce_store + op.src = next_op.src + for dst in next_op.dsts: + op.dsts.append(dst) + # Merge operations + merge_op(op, next_op) + tb.ops.remove(next_op) + queue.remove(next_op) + return True + return False + + def try_remove_op(self, pending_remove_op: Op, condition: bool) -> bool: + if condition: + remove_op(pending_remove_op) + return True + return False + + +class DagOptimizer: + + def __init__(self, instruction_dag: InstructionDAG): + self.optimizer = _InstructionOptimizer() + self.dag = instruction_dag + + def remove_redundant_signal_wait(self): + # For packet ops, we can remove signal/wait + for rank, rank_tbs in enumerate(self.dag.tbs): + for tbid, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + fused = False + if op.inst == Instruction.put_packet: + for next_op in op.next: + fused = self.optimizer.try_remove_op(next_op, next_op.inst == Instruction.signal) + if fused: + break + elif op.inst == Instruction.reduce_packet or op.inst == Instruction.copy_packet: + for prev_op in op.prev: + fused = self.optimizer.try_remove_op(prev_op, prev_op.inst == Instruction.wait) + if fused: + break + if fused: + continue + queue = queue[1:] + + def fuse_instructions(self): + self._fuse_instructions_using_proxy_channel() + self._fuse_same_instructions() + self._optimize_rrcs_rs() + self._optimize_group_ops() + self._compact_instructions() + + # put(src, sbuf, si, dst, dbuf, di) signal(src, sbuf, si, dst, dbuf, di) + # -> putWithSignal(src, sbuf, si, dst, dbuf, di) + # put(src, sbuf, si, dst, dbuf, di) signal(src, sbuf, si, dst, dbuf, di) flush(src, sbuf, si, dst, dbuf, di) + # -> putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di) + def _fuse_instructions_using_proxy_channel(self): + inst_followup_map = { + Instruction.put: Instruction.signal, + Instruction.put_with_signal: Instruction.flush, + } + for rank, rank_tbs in enumerate(self.dag.tbs): + for tbid, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + fused = False + if op.inst in inst_followup_map: + for next_op in op.next: + fused = self.optimizer.try_fuse_instructions_using_proxy_channel( + op, next_op, tb, queue, inst_followup_map[op.inst] + ) + if fused: + break + if fused: + continue + queue = queue[1:] + + # rrc(_,_,_,dst,dbuf,di) rrc(_,_,_,dst,dbuf,di) -> rrc(list[src,sbuf,si], dst, dbuf, di) + # signal(_,_,_,dst,dbuf,di) signal(_,_,_,dst,dbuf,di) -> signal(_,_,_,list[dst,dbuf,di]) + # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) + # reduce(_,_,_,dst,dbuf,di) reduce(_,_,_,dst,dbuf,di) -> reduce(list[src,sbuf,si], dst, dbuf, di) + # reduce_packet(_,_,_,dst,dbuf,di) reduce_packet(_,_,_,dst,dbuf,di) -> reduce_packet(list[src,sbuf,si], dst, dbuf, di) + def _fuse_same_instructions(self): + # Mapping instruction to their respective condition checks and same buffer function + instruction_handlers = { + Instruction.read_reduce_copy: same_buf_dst, + Instruction.reduce: same_buf_dst, + Instruction.reduce_packet: same_buf_dst, + Instruction.signal: same_buf_src, + Instruction.wait: same_buf_dst, + } + + for _, rank_tbs in enumerate(self.dag.tbs): + for _, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + fused = False + inst_type = op.inst + if inst_type in instruction_handlers: + for next_op in op.next: + same_buf_func = instruction_handlers[inst_type] + if self.optimizer.try_merge_same_instructions( + op, next_op, tb, queue, inst_type, same_buf_func + ): + fused = True + break + if fused: + continue + queue = queue[1:] + + # rrc(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rrcs(_,_,_,_,_,_) + # reduce(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rs(_,_,_,_,_,_) + def _optimize_rrcs_rs(self): + inst_types = [ + Instruction.read_reduce_copy, + Instruction.reduce, + Instruction.reduce_packet, + Instruction.read_reduce_copy_send, + Instruction.reduce_send, + Instruction.reduce_send_packet, + ] + for _, rank_tbs in enumerate(self.dag.tbs): + for _, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + fused = False + if op.inst in inst_types: + for next_op in op.next: + fused = self.optimizer.try_fuse_with_put(op, next_op, tb, queue) + if fused: + break + if fused: + continue + queue = queue[1:] + + # glre(srcs, sbuf, si, _, _, _), gstore (_, _, _, dsts, dbuf, di) -> glres(srcs, sbuf, si, dsts, dbuf, di) + def _optimize_group_ops(self): + inst_types = [ + Instruction.group_load_reduce, + ] + for _, rank_tbs in enumerate(self.dag.tbs): + for _, tb in rank_tbs.items(): + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + fused = False + if op.inst in inst_types: + for next_op in op.next: + fused = self.optimizer.try_fuse_with_group_store(op, next_op, tb, queue) + if fused: + break + if fused: + continue + queue = queue[1:] + + # merge ops which are independent of other operations and no other operations in between + # get(src, sbuf. si, dst, dbuf, di) get(src, sbuf, si, dst, dbuf, di) -> get(list[src,sbuf,si], list[dst,dbuf,di]) + # put(src, sbuf, si, dst, dbuf, di) put(src, sbuf, si, dst, dbuf, di) -> put(list[src,sbuf,si], list[dst,dbuf,di]) + # putWithSignal/putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di) + # putWithSignal/putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di) + # -> putWithSignal/putWithSignalAndFlush(list[src,sbuf,si], list[dst,dbuf,di]) + # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_]) + def _compact_instructions(self): + campactable_inst = [ + Instruction.get, + Instruction.put, + Instruction.put_packet, + Instruction.put_with_signal, + Instruction.put_with_signal_and_flush, + Instruction.signal, + Instruction.flush, + Instruction.wait, + ] + for _, rank_tbs in enumerate(self.dag.tbs): + for _, tb in rank_tbs.items(): + if tb.id == -1: + continue + queue = list(tb.ops) + while len(queue) > 0: + op = queue[0] + fused = False + if op.inst in campactable_inst: + fused = self.optimizer.try_compact_instructions( + op, tb, queue, op.inst, same_src_dst_buffer_type + ) + + if fused: + continue + queue = queue[1:] diff --git a/python/mscclpp/language/ir.py b/python/mscclpp/language/ir.py new file mode 100644 index 000000000..3b84b5298 --- /dev/null +++ b/python/mscclpp/language/ir.py @@ -0,0 +1,534 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import asdict, dataclass +import json +from typing import Dict, List, Optional, Union + +from mscclpp.language.types import Buffer, ChannelType, Op, Program, Instruction + +_local_src_insts_mscclpp: set = { + Instruction.put, + Instruction.put_packet, + Instruction.signal, + Instruction.flush, + Instruction.put_with_signal, + Instruction.put_with_signal_and_flush, + Instruction.copy, + Instruction.copy_packet, + Instruction.transform_to_packet, + Instruction.reduce, + Instruction.reduce_packet, + Instruction.reduce_send, + Instruction.reduce_send_packet, + Instruction.group_load_reduce_store, + Instruction.group_store, +} +_local_dst_insts_mscclpp: set = { + Instruction.get, + Instruction.wait, + Instruction.read_reduce_copy, + Instruction.copy, + Instruction.copy_packet, + Instruction.transform_to_packet, + Instruction.reduce, + Instruction.read_reduce_copy_send, + Instruction.reduce_send, + Instruction.reduce_packet, + Instruction.reduce_send_packet, + Instruction.group_load_reduce_store, + Instruction.group_load_reduce, +} + +_insts_no_need_sync_barrier: set = { + Instruction.copy_packet, + Instruction.reduce_packet, + Instruction.reduce_send_packet, + Instruction.barrier, +} + + +def ir_to_json(program: Program): + # Figure out sizes of buffers based on usage + buffer_sizes = defaultdict(lambda: 0) + for gpu in program.gpus: + for tb in gpu.threadblocks: + for op in tb.ops: + if op.inst in _local_src_insts_mscclpp: + key = (gpu.rank, op.src.buffer) + buffer_sizes[key] = max(buffer_sizes[key], op.src.index + op.src.size) + for src in op.srcs: + key = (gpu.rank, src.buffer) + buffer_sizes[key] = max(buffer_sizes[key], src.index + src.size) + if op.inst in _local_dst_insts_mscclpp: + key = (gpu.rank, op.dst.buffer) + buffer_sizes[key] = max(buffer_sizes[key], op.dst.index + op.dst.size) + # ignore remote buffers + if ( + op.inst != Instruction.read_reduce_copy_send + and op.inst != Instruction.reduce_send + and op.inst != Instruction.reduce_send_packet + ): + for dst in op.dsts: + key = (gpu.rank, dst.buffer) + buffer_sizes[key] = max(buffer_sizes[key], dst.index + dst.size) + for gpu in program.gpus: + gpu.input_chunks = max(buffer_sizes[(gpu.rank, Buffer.input)], gpu.input_chunks) + gpu.output_chunks = max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks) + gpu.scratch_chunks = max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_chunks) + + # Since LL protocol will double the scratch size. We need to make sure all GPUs have the same scratch size. + # Otherwise the offset calculation will be wrong. + if program.protocol == "LL": + max_scratch = max(gpu.scratch_chunks for gpu in program.gpus) + for gpu in program.gpus: + gpu.scratch_chunks = max_scratch + + # get channel info for each GPU and threadblock + for gpu in program.gpus: + gpu.threadblocks = sorted(gpu.threadblocks, key=lambda tb: tb.id) + chan_dict = {} + # the channel key is the tuple (srcBuffer, dstBuffer, type) + for tb in gpu.threadblocks: + for ch in tb.channels: + key = (ch.srcBuffer, ch.dstBuffer, ch.type) + if key not in chan_dict: + chan_dict[key] = [(tb.id, ch.connected_to)] + else: + chan_dict[key].append((tb.id, ch.connected_to)) + for key, value in chan_dict.items(): + chan_dict[key] = sorted(value) + gpu.channels = chan_dict + + # Remove the dependencies of wait after signal. They are actually depends on remote chunk + for gpu in program.gpus: + for tb in gpu.threadblocks: + for op in tb.ops: + if op.inst == Instruction.wait: + op.depends = list(filter(lambda dep: dep.inst != Instruction.signal, op.depends)) + + # Filter out redundant dependencies + # e.g. if op1 and op2 depend on op, and op1 happens before op2 + # then op2 does not need to explicitly depend on op + for gpu in program.gpus: + for tb in gpu.threadblocks: + running_depends = [] + for op in tb.ops: + op.depends = list(filter(lambda dep: dep not in running_depends, op.depends)) + running_depends = running_depends + op.depends + + # Do some additional postprocessing of operations: + # - Expand operations with dependencies with no-ops + for gpu in program.gpus: + for tb in gpu.threadblocks: + new_ops = [] + for op in tb.ops: + if op.inst in _insts_no_need_sync_barrier: + new_ops.append(op) + continue + # Expand extra dependencies into nop operations + nop = Op(Instruction.nop, -1, None, None, []) + for i, dep in enumerate(op.depends): + # barrier already syncs all threads + if dep.inst != Instruction.barrier: + nop.depends.append(dep) + if len(new_ops) > 0 and ( + new_ops[-1].inst == Instruction.barrier or new_ops[-1].inst == Instruction.nop + ): + new_ops[-1].depends.extend(nop.depends) + elif len(nop.depends) > 0: + new_ops.append(nop) + new_ops.append(op) + tb.ops = new_ops + + # update step and tid for ops + for gpu in program.gpus: + for tb in gpu.threadblocks: + for i, op in enumerate(tb.ops): + op.step = i + op.tb = tb.id + + # Need to calculate channel info for each GPU + nchannels = 0 + for gpu in program.gpus: + max_tb_channels = 0 + if len(gpu.threadblocks) > 0: + max_tb_channels = max(tb.channel + 1 for tb in gpu.threadblocks) + nchannels = max(nchannels, max_tb_channels) + return _dump_to_json(program) + + +@dataclass +class _JsonInstruction: + name: str + i_buff: Optional[Dict[str, str]] = None + i_cids: Optional[List[Dict[str, Union[int, List[int]]]]] = None + o_buff: Optional[Dict[str, str]] = None + o_cids: Optional[List[Dict[str, Union[int, List[int]]]]] = None + src: Optional[int] = None + srcs: Optional[List[Dict[str, Union[int, str]]]] = None + srcbuff: Optional[str] = None + srcoff: Optional[int] = None + dst: Optional[int] = None + dsts: Optional[List[Dict[str, Union[int, str]]]] = None + dstbuff: Optional[str] = None + dstoff: Optional[int] = None + ctype: Optional[str] = None + cnt: Optional[int] = None + deps: Optional[List[Dict[str, int]]] = None + nthread_blocks: Optional[int] = None + barrier_id: Optional[int] = None + + +class _OpConverter(ABC): + def get_channel_ids(self, chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_type): + channel_ids = [] + key = (src_buffer, dst_buffer, chan_type) + if chan_type == ChannelType.nvls: + ranks = [] + for c in chunk_list: + ranks.append(c.rank) + channel_ids.extend( + [{"id": id} for id, ele in enumerate(tb_channel_dict[key]["connectedTo"]) if set(ele) == set(ranks)] + ) + else: + for c in chunk_list: + channel_ids.extend( + [ + {"id": id, "off": c.index} + for id, ele in enumerate(tb_channel_dict[key]["connectedTo"]) + if ele == c.rank + ] + ) + return channel_ids + + @abstractmethod + def to_json(self, op: Op) -> _JsonInstruction: + pass + + +class _SignalFlushConverter(_OpConverter): + def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction: + dst_channel_ids = self.get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + return _JsonInstruction( + name=op.inst.value, + o_buff=o_buff, + o_cids=dst_channel_ids, + ctype=op.channel_type.value, + cnt=op.cnt(), + ) + + +class _WaitConverter(_OpConverter): + def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction: + src_channel_ids = self.get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + return _JsonInstruction( + name=op.inst.value, + i_buff=i_buff, + i_cids=src_channel_ids, + ctype=op.channel_type.value, + cnt=op.cnt(), + ) + + +class _ReadReduceCopyConverter(_OpConverter): + def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction: + src_channel_ids = self.get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + dst = op.dst + src = op.dst # TODO(binyli): fix this + return _JsonInstruction( + name=op.inst.value, + i_buff=i_buff, + dst=dst.rank, + dstbuff=dst.buffer.value, + dstoff=dst.index, + src=src.rank, + srcbuff=src.buffer.value, + srcoff=src.index, + i_cids=src_channel_ids, + ctype=op.channel_type.value, + cnt=op.cnt(), + ) + + +class _ReadReduceCopySendConverter(_OpConverter): + def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction: + src_channel_ids = self.get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + dst_channel_ids = self.get_channel_ids( + op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, op.channel_type + ) + i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} + dst = op.dst + src = op.dst # TODO(binyli): fix this + return _JsonInstruction( + name=op.inst.value, + i_buff=i_buff, + i_cids=src_channel_ids, + o_buff=o_buff, + o_cids=dst_channel_ids, + src=src.rank, + srcbuff=src.buffer.value, + srcoff=src.index, + dst=dst.rank, + dstbuff=dst.buffer.value, + dstoff=dst.index, + ctype=op.channel_type.value, + cnt=op.cnt(), + ) + + +class _ReduceSendConverter(_OpConverter): + def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction: + dst_channel_ids = self.get_channel_ids( + op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.sm + ) + o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value} + srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) + dst = op.dst + src = op.dst # TODO(binyli): fix this + return _JsonInstruction( + name=op.inst.value, + o_buff=o_buff, + o_cids=dst_channel_ids, + src=src.rank, + srcbuff=src.buffer.value, + srcoff=src.index, + srcs=srcs, + dst=dst.rank, + dstbuff=dst.buffer.value, + dstoff=dst.index, + ctype=op.channel_type.value, + cnt=op.cnt(), + ) + + +class _ReduceConverters(_OpConverter): + def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction: + srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) + dst = op.dst + src = op.dst + return _JsonInstruction( + name=op.inst.value, + srcs=srcs, + dst=dst.rank, + dstbuff=dst.buffer.value, + dstoff=dst.index, + src=src.rank, + srcbuff=src.buffer.value, + srcoff=src.index, + ctype=op.channel_type.value, + cnt=op.cnt(), + ) + + +class _NopConverter(_OpConverter): + def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction: + return _JsonInstruction( + name=op.inst.value, + deps=list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)), + ) + + +class _BarrierConverter(_OpConverter): + def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction: + return _JsonInstruction( + name=op.inst.value, + nthread_blocks=len(op.extra["tb_list"]), + barrier_id=op.extra["barrier_id"], + ) + + +class _PutConverter(_OpConverter): + def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction: + dst_channel_ids = self.get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs)) + return _JsonInstruction( + name=op.inst.value, + o_buff=o_buff, + o_cids=dst_channel_ids, + srcs=srcs, + ctype=op.channel_type.value, + cnt=op.cnt(), + ) + + +class _GetConverter(_OpConverter): + def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction: + src_channel_ids = self.get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value} + dsts = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.dsts)) + return _JsonInstruction( + name=op.inst.value, + i_buff=i_buff, + i_cids=src_channel_ids, + dsts=dsts, + ctype=op.channel_type.value, + cnt=op.cnt(), + ) + + +class _CopyConverter(_OpConverter): + def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction: + src = op.src + dst = op.dst + return _JsonInstruction( + name=op.inst.value, + src=src.rank, + srcbuff=src.buffer.value, + srcoff=src.index, + dst=dst.rank, + dstbuff=dst.buffer.value, + dstoff=dst.index, + ctype=op.channel_type.value, + cnt=op.cnt(), + ) + + +class _GroupLoadReduceStoreConverter(_OpConverter): + def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction: + src = op.src + dst = op.dst + src_channel_ids = self.get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + dst_channel_ids = self.get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type) + return _JsonInstruction( + name=op.inst.value, + src=src.rank, + srcbuff=src.buffer.value, + srcoff=src.index, + dst=dst.rank, + dstbuff=dst.buffer.value, + dstoff=dst.index, + i_cids=src_channel_ids, + o_cids=dst_channel_ids, + ctype=op.channel_type.value, + cnt=op.cnt(), + ) + + +_json_converter_map: Dict[Instruction, _OpConverter] = { + Instruction.signal: _SignalFlushConverter(), + Instruction.flush: _SignalFlushConverter(), + Instruction.wait: _WaitConverter(), + Instruction.read_reduce_copy: _ReadReduceCopyConverter(), + Instruction.read_reduce_copy_send: _ReadReduceCopySendConverter(), + Instruction.reduce_send: _ReduceSendConverter(), + Instruction.reduce_send_packet: _ReduceSendConverter(), + Instruction.reduce: _ReduceConverters(), + Instruction.reduce_packet: _ReduceConverters(), + Instruction.nop: _NopConverter(), + Instruction.barrier: _BarrierConverter(), + Instruction.put: _PutConverter(), + Instruction.put_packet: _PutConverter(), + Instruction.put_with_signal: _PutConverter(), + Instruction.put_with_signal_and_flush: _PutConverter(), + Instruction.get: _GetConverter(), + Instruction.copy: _CopyConverter(), + Instruction.copy_packet: _CopyConverter(), + Instruction.transform_to_packet: _CopyConverter(), + Instruction.group_load_reduce_store: _GroupLoadReduceStoreConverter(), +} + + +def _dump_to_json(program: Program): + gpus = [] + + def remove_empty_fields(d): + return {k: v for k, v in d.items() if v not in [None, "", [], {}]} + + max_scratch = max(gpu.scratch_chunks for gpu in program.gpus) + max_input = max(gpu.input_chunks for gpu in program.gpus) + max_output = max(gpu.output_chunks for gpu in program.gpus) + + for id, gpu in enumerate(program.gpus): + gpu_instance = { + "id": id, + "inputChunks": gpu.input_chunks, + "outputChunks": gpu.output_chunks, + "scratchChunks": gpu.scratch_chunks, + "chunkGroups": program.num_chunk_groups, + "threadblocks": [], + "channels": [], + } + for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): + obj = { + "srcbuff": srcBuffer.value if hasattr(srcBuffer, "value") else srcBuffer, + "dstbuff": dstBuffer.value if hasattr(dstBuffer, "value") else dstBuffer, + "type": type.value, + "connectedTo": [ch[1] for ch in channels], + } + if type == ChannelType.nvls: + obj["connectedTo"] = [sorted(list(peers)) for peers in obj["connectedTo"]] + gpu_instance["channels"].append(obj) + gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"])) + gpu_instance["channels"] = sorted(gpu_instance["channels"], key=lambda x: (x["srcbuff"], x["dstbuff"])) + + # render for GPU NVLS channels + for i, chan in enumerate(gpu_instance["channels"]): + if chan["type"] == "nvls": + buff = chan["srcbuff"] + buffer_size = ( + max_input + if buff == Buffer.input.value + else max_output if buff == Buffer.output.value else max_scratch + ) + gpu_instance["channels"][i] = { + "buff": chan["srcbuff"], + "type": chan["type"], + "rankGroups": [{"size": buffer_size, "ranks": ranks} for ranks in chan["connectedTo"]], + } + + for tb in gpu.threadblocks: + if tb.id < 0: + continue + ops = [] + tb_channels = [] + tb_channel_dict = {} + for (srcBuffer, dstBuffer, type), channels in gpu.channels.items(): + obj = { + "srcbuff": srcBuffer.value if hasattr(srcBuffer, "value") else srcBuffer, + "dstbuff": dstBuffer.value if hasattr(dstBuffer, "value") else dstBuffer, + "type": type.value, + "chanIds": [id for id, ele in enumerate(channels) if ele[0] == tb.id], + "connectedTo": [ele[1] for ele in channels if ele[0] == tb.id], + } + if len(obj["chanIds"]) > 0: + tb_channel_dict[(srcBuffer, dstBuffer, type)] = obj + tb_channels.append(obj) + tb_channels = filter(lambda x: x["type"] != "none", tb_channels) + tb_channels = sorted(tb_channels, key=lambda x: (x["srcbuff"], x["dstbuff"])) + for op in tb.ops: + if op.tb == -1: + continue + instr = _json_converter_map[op.inst].to_json(op, tb_channel_dict) + ops.append(remove_empty_fields(asdict(instr))) + threadblock = { + "id": tb.id, + "ops": ops, + "channels": list( + map( + lambda x: {"src": x["srcbuff"], "dst": x["dstbuff"], "ctype": x["type"], "cids": x["chanIds"]}, + tb_channels, + ) + ), + } + gpu_instance["threadblocks"].append(threadblock) + gpus.append(gpu_instance) + obj = { + "name": program.name, + "collective": program.collective, + "protocol": program.protocol, + "inplace": program.inplace, + "gpus": gpus, + "num_threads_per_block": program.num_threads_per_block, + "use_double_scratch_buffer": program.use_double_scratch_buffer, + "min_message_size": program.min_message_size, + "max_message_size": program.max_message_size, + } + return json.dumps(obj, indent=2) diff --git a/python/mscclpp/language/program.py b/python/mscclpp/language/program.py new file mode 100644 index 000000000..6cf0d15b1 --- /dev/null +++ b/python/mscclpp/language/program.py @@ -0,0 +1,433 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import dataclass +from mscclpp.language.collectives import Collective +from mscclpp.language.buffer import * +from mscclpp.language.types import ChannelType, ChunkRef, ReplicationPolicy, Threadblock +from mscclpp.language.ir import * +from mscclpp.language.dag import DagOptimizer, DagLower, InstructionDAG +from mscclpp.language.rank import Rank + +_current_program = None + + +def _curr(): + global _current_program + if _current_program == None: + raise RuntimeError("No Program in context") + return _current_program + + +# For msccl++ program, we have one assumption that for channel can be identified by (send_buffer, recv_buffer, type, send_tb/recv_tb) +# which means the send_tb and recv_tb should be the same for a pair of signal and wait, also same for put/get operation. +# If one sender what to send data to peer want to use different tb in receiver side. We need to send to same tb in receiver side first, +# then performance a across tb sync. This is a limitation of current implementation. +class MSCCLPPProgram: + def __init__( + self, + name: str, + collective: Collective, + num_ranks: int, + instances: int, + protocol: str = "Simple", + instr_fusion: bool = True, + replication_policy: ReplicationPolicy = ReplicationPolicy.duplicated, + num_threads_per_block: int = 1024, + use_double_scratch_buffer: bool = False, + min_message_size: int = 0, + max_message_size: int = 2**64 - 1, + ): + self.name = name + self.collective = collective + self.num_ranks = num_ranks + self.instances = instances + self.protocol = protocol + self.instr_fusion = instr_fusion + self.replication_policy = replication_policy + self.num_threads_per_block = num_threads_per_block + self.use_double_scratch_buffer = use_double_scratch_buffer + self.min_message_size = min_message_size + self.max_message_size = max_message_size + assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL" + self.run_opt = True # Runs optimization passes + # Initialize the input buffers + self.buffers = collective.init_buffers() + self.instr_dag = InstructionDAG(self.num_ranks, self.buffers) + self.ranks = [] + for r in range(self.num_ranks): + self.ranks.append(Rank(r)) + for index, chunk in enumerate(self.buffers[r][Buffer.input]): + buffer, index = self.collective.get_buffer_index(r, Buffer.input, index) + ref = self.get_ref(r, buffer, index, 1) + # self.chunk_dag.init_chunk(chunk, ref) + self.instr_dag.add_start(r, buffer, index, ref) + + def __enter__(self): + global _current_program + if _current_program != None: + raise RuntimeError("There is already a MSCCLPP Program in context") + _current_program = self + + def __exit__(self, exc_type, exc_value, exc_traceback): + global _current_program + if _current_program != self: + raise RuntimeError("This program is not currently in context") + _current_program = None + + def _convert_to_execution_plan(self): + ops = self.instr_dag.convert_set_list() + ops = sorted(ops, key=lambda x: x.step) + for op in ops: + rank = op.rank + tbid = op.tb + if tbid not in self.instr_dag.tbs[rank]: + self.instr_dag.tbs[rank][tbid] = Threadblock(id=tbid) + tb = self.instr_dag.tbs[rank][tbid] + tb.ops.append(op) + + def get_rank_ref(self, rank): + return RankRef(rank, self) + + # Tracks a send operation on the buffers + def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): + src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) + dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) + sb = self.buffers[src][src_buffer] + db = self.buffers[dst][dst_buffer] + for i in range(size): + db[dst_index + i] = sb[src_index + i] + + # Tracks a reduce operation on the buffers + def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): + src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) + dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) + sb = self.buffers[src][src_buffer] + db = self.buffers[dst][dst_buffer] + for i in range(size): + reduce_chunk = db[dst_index + i] + sent_chunk = sb[src_index + i] + db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk) + + def get_ref(self, rank, buffer, index, size): + buffer, index = self.collective.get_buffer_index(rank, buffer, index) + return Ref(rank, buffer, index, size, self) + + def get_chunks(self, rank, buffer, index, size=1): + chunks = [None] * size + for i in range(0, size): + if self.buffers[rank][buffer] and index + i < len(self.buffers[rank][buffer]): + chunks[i] = self.buffers[rank][buffer][index + i] + else: + chunks[i] = None + return chunks + + def check_buffer_exists(self, rank, name): + if name not in self.buffers[rank]: + self.buffers[rank][name] = BufferSlice(Buffer.scratch, name) + + # Checks that all chunks that should be on each rank + # are present in the output buffer. + def check(self): + return self.collective.check(self) + + # Lower program to MSCCLPP + def lower(self): + self._convert_to_execution_plan() + self.instr_dag.complete_channels() + dag_optimizer = DagOptimizer(self.instr_dag) + dag_optimizer.remove_redundant_signal_wait() + if self.instr_fusion: + dag_optimizer.fuse_instructions() + dag_lower = DagLower(self.instr_dag) + gpu_prgms = dag_lower.lower(self.instances, self.replication_policy) + program = Program( + self.name, + self.collective.name, + self.collective.inplace, + self.protocol, + gpu_prgms, + self.collective.num_chunk_groups * self.instances, + self.num_threads_per_block, + self.use_double_scratch_buffer, + self.min_message_size, + self.max_message_size, + ) + for gpu in program.gpus: + gpu.input_chunks = len(self.buffers[gpu.rank][Buffer.input]) * self.instances + if not self.collective.inplace: + gpu.output_chunks = len(self.buffers[gpu.rank][Buffer.output]) * self.instances + return program + + def generate_json(self): + return ir_to_json(self.lower()) + + +def Json(): + print(_curr().generate_json()) + + +@dataclass +class RankRef: + rank: int + prog: MSCCLPPProgram + + def _get_barrier_id(self, tb_list) -> int: + return self.prog.ranks[self.rank].get_barrier_id(tb_list) + + def barrier(self, tb_list): + barrier_id = self._get_barrier_id(tb_list) + return self.prog.instr_dag.add_barrier(self.rank, tb_list, barrier_id) + + +@dataclass +class Ref(ChunkRef): + prog: MSCCLPPProgram + + def __repr__(self): + return f"Ref(Buffer:{self.buffer}, Index:{self.index}, Size:{self.size}, Rank:{self.rank})" + + def _end(self): + return self.index + self.size + + def _get_chunk(self, index): + return self.prog.buffers[self.rank][self.buffer][index] + + def split(self, num): + assert self.size % num == 0, f"Trying to split a chunk of {self.size} elements into {num} parts" + chunks = [None] * num + size = self.size // num + for i in range(num): + index = self.index + i * size + chunks[i] = self.prog.get_ref(self.rank, self.buffer, index, size) + return chunks + + def group(self, other): + assert self.rank == other.rank, f"Trying to concatenate chunks on ranks {self.rank} and {other.rank}" + assert self.buffer == other.buffer, f"Trying to concatenate chunks in {self.buffer} and {other.buffer}" + if self.index < other.index: + first = self + second = other + else: + first = other + second = self + + end = max(first._end(), second._end()) + return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog) + + def _get_buffer_index(self, remote_rank, buffer, index): + if index == -1 and buffer == None: + return self.buffer, self.index + elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: + return buffer, self.prog.buffers[remote_rank][buffer].instance_size() + return buffer, index + + def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False): + self.prog.check_buffer_exists(dst, buffer) + assert self.rank != dst, "Cannot put to the same rank" + buffer, index = self._get_buffer_index(dst, buffer, index) + + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + if use_packet: + self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, True) + self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none) + self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none) + else: + self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type) + return dst_chunkref + + def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): + return self._put(dst, buffer, index, sendtb, chan_type) + + def put_packet( + self, + dst, + buffer=None, + index=-1, + sendtb=-1, + chan_type=ChannelType.sm, + temp_buffer=None, + temp_buffer_index=-1, + ): + chunk_ref = self + if chan_type == ChannelType.proxy: + assert temp_buffer is not None, "Need to specify a temporary buffer for proxy channels" + chunk_ref = self._copy( + self.rank, temp_buffer, temp_buffer_index, sendtb, trans_from_packet=False, trans_to_packet=True + ) + return chunk_ref._put(dst, buffer, index, sendtb, chan_type, True) + + def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): + self.prog.check_buffer_exists(src, buffer) + sender = src + receiver = self.rank + assert sender != receiver, "Cannot get from the same rank" + buffer, index = self._get_buffer_index(src, buffer, index) + + src_chunkref = self.prog.get_ref(src, buffer, index, self.size) + + self.prog.apply_send(src, buffer, index, self.rank, self.buffer, self.index, self.size) + self.prog.instr_dag.add_get(receiver, src_chunkref, self, recvtb, chan_type) + + # for signal and wait, currently we assuem the pair will use the same tb index. In future we need + # to infer the tb index from the instruction DAG Add a channel is define as (send_tb, src_buffer, recv_tb, dst_buffer, type). + # Then we can use DAG info to reduce the number of channels. + def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): + sender = self.rank + receiver = dst + assert sender != receiver, "Cannot signal to the same rank" + buffer, index = self._get_buffer_index(dst, buffer, index) + + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type) + + # only proxy channel need to use this function + def flush(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.proxy): + assert chan_type == ChannelType.proxy, "Only proxy channel can use flush" + sender = self.rank + receiver = dst + assert sender != receiver, "Cannot flush to the same rank" + buffer, index = self._get_buffer_index(dst, buffer, index) + + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + self.prog.instr_dag.add_flush(sender, self, dst_chunkref, sendtb) + + def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): + sender = src + receiver = self.rank + assert sender != receiver, "Cannot wait on the same rank" + buffer, index = self._get_buffer_index(src, buffer, index) + + src_chunkref = self.prog.get_ref(src, buffer, index, self.size) + self.prog.instr_dag.add_wait(receiver, self, src_chunkref, recvtb, chan_type) + + def _copy(self, dst, buffer=None, index=-1, sendtb=-1, trans_from_packet=False, trans_to_packet=False): + self.prog.check_buffer_exists(dst, buffer) + buffer, index = self._get_buffer_index(dst, buffer, index) + + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + # Check if we are copying the chunk to the same index (easy mistake when we are using inplace) + if dst_chunkref == self: + return + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + + assert self.rank == dst, "Chunk copy only supports intra-rank communication" + self.prog.instr_dag.add_copy(self.rank, self, dst_chunkref, sendtb, trans_from_packet, trans_to_packet) + + return dst_chunkref + + # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) + def copy(self, dst, buffer=None, index=-1, sendtb=-1): + return self._copy(dst, buffer, index, sendtb) + + def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1): + return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False) + + def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm, use_packet=False): + dst = self.rank + src = other_chunkref.rank + + self.prog.apply_reduce( + src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size + ) + if use_packet: + assert src == dst, "Packet reduce only supports intra-rank communication" + + if src != dst: + self.prog.instr_dag.add_read_reduce(dst, other_chunkref, self, recvtb, channel_type) + else: + self.prog.instr_dag.add_reduce(src, other_chunkref, self, recvtb, use_packet) + + return self + + # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref + def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm): + return self._reduce(other_chunkref, recvtb, channel_type) + + # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref + def reduce_packet(self, other_chunkref, recvtb=-1): + return self._reduce(other_chunkref, recvtb, use_packet=True) + + # """ + # Group operations. These operations are used to perform collective operations across multiple chunks. + # For now, all chunks must has the same buffer type and offset. + # """ + # Reads the chunk(s) referenced by other_chunkref and reduce into the chunk referenced by this chunkref + def group_load_reduce(self, other_chunkrefs: list, recvtb=-1, chan_type=ChannelType.nvls): + assert ( + len(other_chunkrefs) > 0 and chan_type == ChannelType.nvls + ), "Group load reduce only supports nvls channel" + nranks_per_node = self.prog.collective.num_ranks_per_node + for other_chunkref in other_chunkrefs: + assert ( + self.rank // nranks_per_node == other_chunkref.rank // nranks_per_node + ), "Group load reduce only supports chunks on the same node" + assert self.buffer == other_chunkref.buffer, "Group load reduce only supports chunks with the same buffer" + assert self.index == other_chunkref.index, "Group load reduce only supports chunks with the same index" + + src_chunkref = other_chunkref + self.prog.apply_reduce( + src_chunkref.rank, + src_chunkref.buffer, + src_chunkref.index, + self.rank, + self.buffer, + self.index, + self.size, + ) + self.prog.instr_dag.add_group_load_reduce(self.rank, other_chunkrefs, self, recvtb, chan_type) + return self + + # Copies the chunk(s) referenced by this chunkref onto other_chunkrefs + def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=ChannelType.nvls): + for dst in dsts: + self.prog.check_buffer_exists(dst, buffer) + assert index == -1 or self.index == index, "Group store only supports chunks with the same index" + assert chan_type == ChannelType.nvls, "Group store only supports nvls channel" + + other_chunkrefs = [] + nrank_per_node = self.prog.collective.num_ranks_per_node + for dst in dsts: + # Direct linked + buffer, index = self._get_buffer_index(dst, buffer, index) + + assert self.buffer == buffer, "Group store only supports chunks with the same buffer" + assert ( + self.rank // nrank_per_node == dst // nrank_per_node + ), "Group store only supports chunks on the same node" + + dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) + other_chunkrefs.append(dst_chunkref) + # add new op here + self.prog.instr_dag.add_group_store(self.rank, self, other_chunkrefs, sendtb, chan_type) + + def get_origin_index(self, index=0): + return self._get_chunk(index + self.index).origin_index + + def get_origin_rank(self, index=0): + return self._get_chunk(index + self.index).origin_rank + + def get_dst_index(self, index=0): + return self._get_chunk(index + self.index).dst_index + + def get_dst_rank(self, index=0): + return self._get_chunk(index + self.index).dst_rank + + def print_chunk_info(self, index=0): + print(self._get_chunk(index + self.index)) + + +def chunk(rank, buffer, index, size=1) -> Ref: + if _curr().buffers[rank][buffer][index] is None: + return None + return _curr().get_ref(rank, buffer, index, size) + + +def rank(rank) -> RankRef: + return _curr().get_rank_ref(rank) + + +def Check(): + return _curr().check() diff --git a/python/mscclpp/language/rank.py b/python/mscclpp/language/rank.py new file mode 100644 index 000000000..0daa82d06 --- /dev/null +++ b/python/mscclpp/language/rank.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import dataclass, field +from typing import Dict + + +class BarrierInfo: + def __init__(self, tb_list): + self.tb_list = tb_list + + def __eq__(self, other): + return self.tb_list == other.tb_list + + def __hash__(self): + return hash(tuple(self.tb_list)) + + +@dataclass +class Rank: + rank_id: int + current_max_barrier_id: int = 0 + current_barriers: Dict[BarrierInfo, int] = field(default_factory=dict) + + def get_barrier_id(self, tb_list): + barrier_info = BarrierInfo(tb_list) + if barrier_info in self.current_barriers: + return self.current_barriers[barrier_info] + else: + self.current_barriers[barrier_info] = self.current_max_barrier_id + barrier_id = self.current_max_barrier_id + self.current_max_barrier_id += 1 + return barrier_id diff --git a/python/mscclpp/language/types.py b/python/mscclpp/language/types.py new file mode 100644 index 000000000..f6202ccfe --- /dev/null +++ b/python/mscclpp/language/types.py @@ -0,0 +1,173 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import dataclass, field +from enum import Enum +from typing import Union, List + +from mscclpp.language.buffer import Buffer + + +@dataclass +class Gpu: + rank: int + threadblocks: list = field(default_factory=list) + + # From ncclize + precopies: list = field(default_factory=list) + postcopies: list = field(default_factory=list) + inputs: dict = field(default_factory=dict) + outputs: dict = field(default_factory=dict) + input_chunks: int = 0 + output_chunks: int = 0 + scratch_chunks: int = 0 + scratch: dict = field(default_factory=dict) + channels: dict = field(default_factory=dict) + + def scratch_size(self): + return max((idx for addr, idx in self.scratch.items()), default=-1) + 1 + + +@dataclass +class Program: + name: str + collective: str + inplace: bool + protocol: str + gpus: List[Gpu] = field(default_factory=list) + num_chunk_groups: int = 1 + num_threads_per_block: int = 1024 + use_double_scratch_buffer: bool = False + min_message_size: int = 0 + max_message_size: int = 2**64 - 1 + + +@dataclass +class Threadblock: + channel: int = -1 + send: int = -1 + recv: int = -1 + ops: list = field(default_factory=list) + rbid: int = -1 # threadblock id of the receiver + id: int = -1 + channels: list = field(default_factory=list) + + def __eq__(self, other): + return self is other + + def __hash__(self): + return id(self) + + +class ReplicationPolicy(Enum): + # this means each instance deal with the different chunk + # Chunk A, Chunk B -> Chunk A0, Chunk B0, Chunk A1, Chunk B1 + duplicated = "duplicated" + # this means each instance deal with the different chunk in interleaved way + # Chunk A, Chunk B -> Chunk A0, Chunk A1, Chunk B0, Chunk B1 + interleaved = "interleaved" + # this means pack multi instrances to deal with the same chunk and share the channels + packed = "packed" + + def __str__(self): + return self.value + + +class Instruction(Enum): + start = "start" + nop = "nop" + read_reduce_copy = "rrc" + read_reduce_copy_send = "rrcs" + reduce_send = "rs" + copy = "copy" + reduce = "reduce" + copy_packet = "cpkt" + transform_to_packet = "tpkt" + reduce_send_packet = "rspkt" + reduce_packet = "rpkt" + put = "put" + put_packet = "ppkt" + put_with_signal = "pws" + put_with_signal_and_flush = "pwsf" + get = "get" + wait = "wait" + signal = "signal" + flush = "flush" + barrier = "barrier" + group_store = "gstore" + group_load_reduce = "glre" + group_load_reduce_store = "glres" + + def __str__(self): + return self.value + + +@dataclass +class ChunkRef: + rank: int + buffer: Buffer + index: int + size: int + + def __hash__(self): + return hash((self.rank, self.buffer, self.index, self.size)) + + +class ChannelType(Enum): + proxy = "proxy" + sm = "sm" + none = "none" + nvls = "nvls" + + def __str__(self): + return self.value + + +@dataclass(frozen=True) +class Channel: + srcBuffer: Buffer + dstBuffer: Buffer + type: ChannelType + connected_to: Union[int, List[int]] + + def __hash__(self): + # Ensure connected_to is converted to a tuple if it's a list + connected_to_hashable = tuple(self.connected_to) if isinstance(self.connected_to, list) else self.connected_to + return hash((self.srcBuffer, self.dstBuffer, self.type, connected_to_hashable)) + + +@dataclass +class Op: + inst: Instruction + rank: int + src: ChunkRef + dst: ChunkRef + depends: list = field(default_factory=list) + step: int = -1 # Step in the TB + tb: int = -1 # TB this op is assigned to + prev: list = field(default_factory=list) # List of instructions that happen before + next: list = field(default_factory=list) # List of instructions that happen after + channel: int = -1 + channel_type: ChannelType = ChannelType.none + srcs: list = field(default_factory=list) + dsts: list = field(default_factory=list) + extra: dict = field(default_factory=dict) + + def cnt(self): + if self.src: + if self.dst: + assert self.src.size == self.dst.size + return self.src.size + elif self.dst: + return self.dst.size + else: + return 0 + + def __eq__(self, other): + return self is other + + def __hash__(self): + return id(self) + + def __repr__(self): + return f"Op({self.inst}, {self.rank}, {self.src}, {self.dst}, step:{self.step}, tb:{self.tb})" diff --git a/python/mscclpp/language/utils.py b/python/mscclpp/language/utils.py new file mode 100644 index 000000000..c97af881b --- /dev/null +++ b/python/mscclpp/language/utils.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from mscclpp.language.types import Op + + +def remove_op(op: Op): + for p in op.prev: + p.next.remove(op) + p.next += op.next + p.next = list(set(p.next)) + + for n in op.next: + n.prev.remove(op) + n.prev = op.prev.union(n.prev) + + op.next = [] + op.prev = [] + + +def merge_op(op: Op, other_op: Op): + if other_op in op.next: + op.next.remove(other_op) + other_op.prev.remove(op) + for p in other_op.prev: + p.next.remove(other_op) + p.next.append(op) + + for n in other_op.next: + n.prev.remove(other_op) + n.prev.add(op) + + op.prev = op.prev.union(other_op.prev) + op.next = list(set(op.next + other_op.next)) + + +def circular_dep_after_merge(op: Op, other_op: Op): + root = set([op, other_op]) + frontier = set(op.next) + if other_op in frontier: + frontier.remove(other_op) + frontier = list(frontier.union(other_op.next)) + while len(frontier) > 0: + current = frontier[0] + for n in current.next: + # The root node will be visited again if there is a circular dependency + if n in root: + return True + frontier.append(n) + frontier = frontier[1:] + + +def all_prevs_visited_after_merge(op: Op, other_op: Op): + """ + For case: op2.prev = [op1, op3]. op1.next = [op2]. op3.next = [op2]. And op1 and op2 are satisfied to merge. + We only apply the merge if all previous ops of op2 are visited. (op1 is the last previous op of op2). + """ + step = op.step + for prev in other_op.prev: + if prev.step > step: + return False + return True + + +def same_tb(op1: Op, op2: Op): + return op1.tb == op2.tb and op1.channel == op2.channel + + +def same_count(op1: Op, op2: Op): + return op1.cnt() == op2.cnt() + + +def same_buf_dst(op1: Op, op2: Op): + return op1.dst.buffer == op2.dst.buffer and op1.dst.index == op2.dst.index + + +def same_src_dst_buffer_type(op1: Op, op2: Op): + return op1.src.buffer == op2.src.buffer and op1.dst.buffer == op2.dst.buffer + + +def buf_dst_src_match(op1: Op, op2: Op): + return op1.dst.buffer == op2.src.buffer and op1.dst.index == op2.src.index + + +def same_buf_src(op1: Op, op2: Op): + return op1.src.buffer == op2.src.buffer and op1.src.index == op2.src.index + + +def same_chan_type(op1: Op, op2: Op): + return op1.channel_type == op2.channel_type + + +def same_tb(op1: Op, op2: Op): + return op1.tb == op2.tb diff --git a/python/test/configs/mscclpp_lang_test_config.json b/python/test/configs/mscclpp_lang_test_config.json new file mode 100644 index 000000000..f4cbaf930 --- /dev/null +++ b/python/test/configs/mscclpp_lang_test_config.json @@ -0,0 +1,34 @@ +[ + { + "filename": "allgather_barrier.py", + "args": ["8", "8"] + }, + { + "filename": "allreduce_allpairs_packet.py", + "args": ["8", "8"] + }, + { + "filename": "allreduce_allpairs_get.py", + "args": ["8", "8"] + }, + { + "filename": "allreduce_allpairs.py", + "args": ["8", "8"] + }, + { + "filename": "allreduce_ring.py", + "args": ["8", "8"] + }, + { + "filename": "send_recv_packet.py", + "args": ["2"] + }, + { + "filename": "send_recv_proxy.py", + "args": ["2"] + }, + { + "filename": "allreduce_nvls.py", + "args": ["8", "2"] + } +] diff --git a/python/test/test_generate_mscclpp_lang_result.py b/python/test/test_generate_mscclpp_lang_result.py new file mode 100644 index 000000000..036c71c5f --- /dev/null +++ b/python/test/test_generate_mscclpp_lang_result.py @@ -0,0 +1,47 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +import json +from pathlib import Path +import subprocess + + +def run_examples(input_folder, configs, output_folder): + for config in configs: + file_name = config["filename"] + args = config["args"] + + input_file_path = Path(input_folder) / file_name + # Strip the ".py" from the filename and add ".output" + base_file_name = file_name[:-3] if file_name.endswith(".py") else file_name + base_file_name = base_file_name.replace("/", "_") + output_file_path = Path(output_folder) / f"{base_file_name}.output" + + # Construct the command to run the Python script + command = ["python3", str(input_file_path)] + args + + # Run the command and capture output + with open(output_file_path, "w") as output_file: + result = subprocess.run(command, stdout=output_file, stderr=subprocess.STDOUT, text=True) + + # Optional: Check the return code to handle errors + if result.returncode != 0: + print(f"Error running {file_name}. See {output_file_path} for details.") + + +def main(input_folder, config_path, output_folder): + with open(config_path, "r") as f: + config = json.load(f) + + Path(output_folder).mkdir(parents=True, exist_ok=True) + run_examples(input_folder, config, output_folder) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Process files according to a configuration and save the results.") + parser.add_argument("input_folder", type=str, help="Path to the folder containing the input files.") + parser.add_argument("config", type=str, help="Path to the configuration file.") + parser.add_argument("output_folder", type=str, help="Path to the folder where the processed files will be saved.") + args = parser.parse_args() + main(args.input_folder, args.config, args.output_folder)